sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/utils.py
ADDED
@@ -0,0 +1,153 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from types import MappingProxyType
+from typing import List, Mapping, Tuple, Union
+
+import torch
+
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
+
+def is_fp8_fnuz() -> bool:
+    # only device 0 is checked, this assumes MI300 platforms are homogeneous
+    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+
+
+def is_layer_skipped(
+    prefix: str,
+    ignored_layers: List[str],
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "to have the same precision."
+                )
+    else:
+        is_skipped = prefix in ignored_layers
+
+    assert is_skipped is not None
+    return is_skipped
+
+
+def per_tensor_dequantize(
+    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+
+def convert_to_channelwise(
+    weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Create channelwise buffer
+    weight_scale_channel = torch.empty(
+        (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
+    )
+
+    # Handle scalar tensor case: broadcast same scale to all channels
+    if weight_scale.dim() == 0:
+        weight_scale_channel.fill_(weight_scale.item())
+        return weight_scale_channel
+
+    # Expand each scale to match the size of each logical matrix.
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_scale_channel[start:end, :] = weight_scale[idx]
+        start = end
+
+    return weight_scale_channel
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requanitzation.
+    max_w_scale = weight_scale.max()
+
+    # QKV / MLP is fused in the on disk checkpoint if any of the
+    # weight scales are still set to the default since we initialize
+    # N weight scales for N shards but we only load 1 weight scale
+    # from disk in this case. Skip requantization in this case (since)
+    # we already are quantized with the single scale.
+    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
+    unfused_module_in_checkpoint = (
+        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
+    )
+
+    # If unfused checkpoint, need requanize with the single scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+            if _is_cuda:
+                weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
+            else:
+                weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
+                    weight_dq, max_w_scale
+                )
+            start = end
+
+    return max_w_scale, weight
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
+# Newly generated tensors need to replace existing tensors that are
+# already registered as parameters by vLLM (and won't be freed)
+def replace_parameter(
+    mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
+) -> None:
+
+    old = getattr(mod, name)
+    if (
+        type(old) is type(new)
+        and old.dtype == new.dtype
+        and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
+    ):
+        # If we can just update in-place to avoid re-registering
+        # can be faster if the underlying storage is the same
+        update_tensor_inplace(old, new)
+    else:
+        # Fallback re-register parameter, convert to Parameter if necessary
+        # this not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility
+        if not isinstance(new, torch.nn.Parameter):
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
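For orientation, the expansion that `convert_to_channelwise` performs can be reproduced in a few lines of plain PyTorch. The sketch below restates the loop from the new module with made-up shard widths and scales so it runs without sglang installed; the shapes and values are illustrative assumptions, not taken from the package.

import torch

# Assumed example: a fused qkv projection with three logical shards of
# 1024 / 256 / 256 output channels, each carrying one per-tensor scale.
logical_widths = [1024, 256, 256]
weight_scale = torch.tensor([0.01, 0.02, 0.04], dtype=torch.float32)

# Expand each per-shard scale to one scale per output channel,
# mirroring convert_to_channelwise above.
weight_scale_channel = torch.empty((sum(logical_widths), 1), dtype=torch.float32)
start = 0
for idx, logical_width in enumerate(logical_widths):
    end = start + logical_width
    weight_scale_channel[start:end, :] = weight_scale[idx]
    start = end

print(weight_scale_channel.shape)                 # torch.Size([1536, 1])
print(weight_scale_channel[[0, 1024, 1280], 0])   # tensor([0.0100, 0.0200, 0.0400])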
sglang/srt/layers/quantization/w8a8_fp8.py
CHANGED
@@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import is_hip
@@ -22,12 +24,24 @@ _is_hip = is_hip()
 class W8A8Fp8Config(QuantizationConfig):
     """Config class for W8A8 FP8 Quantization.
 
-
-
+    Weight Quantization:
+        - Method: Static quantization
+        - Granularity: Per-channel
+        - Type: Symmetric
+
+    Activation Quantization:
+        - Method: Dynamic quantization
+        - Granularity: Per-token
+        - Type: Symmetric
+
+    Note:
+    - For models without offline quantization, weights will be quantized during model loading
+    - If CUTLASS is supported: Per-channel weight quantization is used
+    - If CUTLASS is not supported: Falls back to per-token weight quantization
     """
 
-    def __init__(self):
-
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False):
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
-
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
 
     def get_quant_method(
         self,
@@ -72,13 +88,40 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = layer.weight
-
-        if
-
-
-
-
-
+
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = layer.weight_scale.detach()
+            # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
+            if _is_hip:
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=weight_scale
+                )
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        else:
+            # If checkpoint not offline quantized, quantize the weights with per-channel quantization.
+            if self.cutlass_fp8_supported:
+                # if cutlass supported, we use cutlass_scaled_mm
+                # which requires per-channel quantization on weight
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
+                )
+                weight_scale = weight_scale.t().contiguous()
+                if _is_hip:
+                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight, weight_scale=weight_scale
+                    )
+            else:
+                # if cutlass not supported, we fall back to use torch._scaled_mm
+                # which requires per tensor quantization on weight
+                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+                qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
 
     def create_weights(
         self,
@@ -90,6 +133,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs
     ):
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quantization_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
 
         weight_loader = extra_weight_attrs.get("weight_loader")
         self.logical_widths = output_partition_sizes
@@ -98,7 +146,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
             data=torch.empty(
                 sum(output_partition_sizes),
                 input_size_per_partition,
-                dtype=
+                dtype=weight_dtype,
             ),
             input_dim=1,
             output_dim=0,
@@ -106,12 +154,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight", weight)
 
-
-
-
-
-
-
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("weight_scale", weight_scale)
+        else:
+            layer.weight_scale = None
 
     def apply(
         self,
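The docstring added to `W8A8Fp8Config` describes the non-serialized path as per-channel, symmetric weight quantization. In the package this goes through `per_token_group_quant_fp8` with the group size set to the full input dimension; the standalone sketch below only illustrates the idea (one symmetric FP8 scale per output channel) with plain PyTorch (2.1+ for `torch.float8_e4m3fn`), using assumed toy shapes, and is not the kernel sglang actually calls.

import torch

def per_channel_fp8_quant(weight: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # One symmetric scale per output channel (row), mirroring the
    # "per-channel, symmetric" weight path described above.
    finfo = torch.finfo(fp8_dtype)
    amax = weight.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    scale = amax / finfo.max                                   # dequant scale per row
    qweight = (weight.float() / scale).clamp(finfo.min, finfo.max).to(fp8_dtype)
    return qweight, scale

w = torch.randn(8, 16, dtype=torch.bfloat16)  # assumed toy weight
qw, s = per_channel_fp8_quant(w)
print(qw.dtype, qw.shape, s.shape)  # torch.float8_e4m3fn torch.Size([8, 16]) torch.Size([8, 1])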
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from vllm import _custom_ops as ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
@@ -14,6 +13,8 @@ from sglang.srt.utils import is_cuda_available
 _is_cuda_available = is_cuda_available()
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+else:
+    from vllm import _custom_ops as ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -147,7 +148,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available:
+        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -168,76 +169,6 @@ class RotaryEmbedding(CustomOp):
             )
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
-        ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
-        return query, key
-
-    def forward_hpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from habana_frameworks.torch.hpex.kernels import (
-            RotaryPosEmbeddingMode,
-            apply_rotary_pos_emb,
-        )
-
-        positions = positions.flatten()
-        if offsets is not None:
-            positions = positions + offsets
-        num_tokens = positions.shape[0]
-        cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
-        # to query hidden dimension, so the original tensors need to be
-        # expanded
-        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
-        # and expansion of cos/sin tensors via concatenation
-        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
-        # and expansion of cos/sin tensors via repeat_interleave
-        rope_mode: RotaryPosEmbeddingMode
-        if self.is_neox_style:
-            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
-            cos = torch.cat((cos, cos), dim=-1)
-            sin = torch.cat((sin, sin), dim=-1)
-        else:
-            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
-            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
-            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
-
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        return query, key
-
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
@@ -510,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     ):
         super().__init__()
 
-        if rotary_dim != head_size:
-            raise ValueError(
-                f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
-                    rotary_dim != head_size ({rotary_dim}!={head_size})."
-            )
         if is_neox_style is False:
             raise ValueError(
                 "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )
 
+        self.rotary_dim = rotary_dim
         self.head_size = head_size
         self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
@@ -568,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
             * (
                 self.base
                 ** (
-                    torch.arange(0, self.
-                    / self.
+                    torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                    / self.rotary_dim
                 )
             )
         )
@@ -618,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         cos = cos.repeat(1, 2).unsqueeze(-2)
         sin = sin.repeat(1, 2).unsqueeze(-2)
 
-
-
+        query_rot = query[..., : self.rotary_dim]
+        query_pass = query[..., self.rotary_dim :]
+        query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
+        query = torch.cat((query_rot, query_pass), dim=-1)
+
+        key_rot = key[..., : self.rotary_dim]
+        key_pass = key[..., self.rotary_dim :]
+        key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
+        key = torch.cat((key_rot, key_pass), dim=-1)
 
         return query.flatten(-2), key.flatten(-2)
 
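The Phi3LongRoPE change above lifts the old `rotary_dim == head_size` restriction by rotating only the first `rotary_dim` dimensions and passing the rest through unchanged. A minimal standalone sketch of that split follows; the shapes are assumed toy values and the helper mirrors `_rotate_neox` from this file.

import torch

def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Illustrative shapes only: 8 tokens, 4 heads, head_size 128, rotary_dim 96.
num_tokens, num_heads, head_size, rotary_dim = 8, 4, 128, 96
query = torch.randn(num_tokens, num_heads, head_size)
cos = torch.randn(num_tokens, 1, rotary_dim)
sin = torch.randn(num_tokens, 1, rotary_dim)

# Only the first rotary_dim dims are rotated; the rest pass through,
# matching the query_rot / query_pass split in the hunk above.
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)
print(query.shape)  # torch.Size([8, 4, 128])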
@@ -879,8 +813,17 @@ class MRotaryEmbedding(RotaryEmbedding):
         spatial_merge_size: int,
         context_len: int = 0,
         seq_len: Optional[int] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        tokens_per_second: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
-        """
+        """
+        Get mrope input positions and delta value.
+
+        :arg
+            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+        """
 
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
@@ -917,6 +860,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                     )
                     image_index += 1
                     remain_images -= 1
+                    second_per_grid_t = 0
                     ed = ed_image
                 else:
                     t, h, w = (
@@ -924,6 +868,10 @@ class MRotaryEmbedding(RotaryEmbedding):
                         video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                     )
+                    if second_per_grid_ts is not None:
+                        second_per_grid_t = second_per_grid_ts[video_index]
+                    else:
+                        second_per_grid_t = 1.0
                     video_index += 1
                     remain_videos -= 1
                     ed = ed_video
@@ -940,11 +888,11 @@ class MRotaryEmbedding(RotaryEmbedding):
                 )
 
                 t_index = (
-                    torch.arange(llm_grid_t)
-
-
-
-
+                    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
+                    * second_per_grid_t
+                    * tokens_per_second
+                ).flatten()
+
                 h_index = (
                     torch.arange(llm_grid_h)
                     .view(1, -1, 1)
@@ -1172,6 +1120,37 @@ def get_rope(
     return rotary_emb
 
 
+# Copied from transformers
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    # embedding is performed in float
+    cos = cos.unsqueeze(unsqueeze_dim).float()
+    sin = sin.unsqueeze(unsqueeze_dim).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
sglang/srt/layers/sampler.py
CHANGED
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
                 group=self.tp_sync_group,
             )
 
-        return batch_next_token_ids
+        return batch_next_token_ids
 
     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
sglang/srt/lora/layers.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 import torch
 from torch import nn
 
@@ -38,8 +40,22 @@ class BaseLayerWithLoRA(nn.Module):
     def set_lora_info(self, *args):
         pass
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        pass
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        pass
+
 
 class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+    """
+    Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
+
+    Note: The current version does not yet implement the LoRA functionality.
+    This class behaves exactly the same as the base VocabParallelEmbedding.
+    Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
+    """
+
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,
@@ -101,6 +117,16 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        B = B[start_idx:end_idx, :]
+        return B
+
 
 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
@@ -120,6 +146,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
+            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
             self.B_buffer_gate_up = torch.cat(
                 (B_buffer[0], B_buffer[1]), dim=-2
@@ -142,6 +169,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             else base_output + lora_output * self.scaling
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        # Since the outputs for both gate and up are identical, we use a random one.
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        return B[:, start_idx:end_idx, :]
+
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
@@ -210,6 +247,27 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             else base_output + lora_output * self.scaling
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(
+        self, B: List[torch.Tensor], tp_rank: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        B_q, B_kv = B
+        base_layer = self.base_layer
+        q_proj_shard_size = base_layer.q_proj_shard_size
+        kv_proj_shard_size = base_layer.kv_proj_shard_size
+        num_kv_head_replicas = base_layer.num_kv_head_replicas
+
+        q_start_idx = q_proj_shard_size * tp_rank
+        q_end_idx = q_start_idx + q_proj_shard_size
+
+        kv_shard_id = tp_rank // num_kv_head_replicas
+        kv_start_idx = kv_proj_shard_size * kv_shard_id
+        kv_end_idx = kv_start_idx + kv_proj_shard_size
+
+        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
@@ -274,6 +332,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.input_size_per_partition
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        A = A[:, start_idx:end_idx].contiguous()
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        return B
+
 
 def get_lora_layer(
     layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
sglang/srt/lora/lora.py
CHANGED
@@ -39,16 +39,9 @@ class LoRALayer(nn.Module):
         super().__init__()
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
-        self.weights: Dict[str, torch.Tensor] = {}
-        self.weight_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
 
-
-
-            self.weight_gpu[name] = None
+        # lora weights in cpu. The weights are loaded from checkpoint.
+        self.weights: Dict[str, torch.Tensor] = {}
 
 
 class LoRAAdapter(nn.Module):
@@ -77,19 +70,6 @@ class LoRAAdapter(nn.Module):
         )
 
         self.weights: Dict[str, torch.Tensor] = {}
-        self.weights_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
-        for layer in self.layers:
-            layer.load_to_gpu()
-
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = None
-        for layer in self.layers:
-            layer.offload_from_gpu()
 
     # initialize the LoRA weights to cpu
     def initialize_weights(self):