llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,26 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
-
# All rights reserved.
|
|
9
|
-
#
|
|
10
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
11
|
-
# top-level folder for each specific model found within the models/ directory at
|
|
12
|
-
# the top-level of this source tree.
|
|
13
|
-
|
|
14
|
-
import collections
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def get_negative_inf_value(dtype):
    """Return the most negative finite value representable by ``dtype``.

    ``dtype`` must be a floating-point torch dtype, since ``torch.finfo`` only
    accepts floating-point types.
    """
    dtype_info = torch.finfo(dtype)
    return dtype_info.min
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def to_2tuple(x):
    """Return ``x`` unchanged if it is iterable, otherwise the pair ``(x, x)``.

    Iterables (lists, tuples, strings, ...) are returned as-is: their length is not
    checked and they are not converted to a tuple.
    """
    # Import the submodule explicitly: a plain ``import collections`` does not by
    # itself guarantee that the ``collections.abc`` attribute is bound.
    from collections.abc import Iterable

    if isinstance(x, Iterable):
        return x
    return (x, x)
|
|
Binary file
|
|
@@ -1,316 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
# type: ignore
|
|
8
|
-
import os
|
|
9
|
-
from typing import Any, cast
|
|
10
|
-
|
|
11
|
-
import torch
|
|
12
|
-
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
|
13
|
-
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
|
14
|
-
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
|
15
|
-
from torch import Tensor, nn
|
|
16
|
-
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
|
17
|
-
|
|
18
|
-
from ...datatypes import QuantizationMode
|
|
19
|
-
from ...quantize_impls import (
|
|
20
|
-
Fp8ScaledWeights,
|
|
21
|
-
ffn_swiglu,
|
|
22
|
-
load_fp8,
|
|
23
|
-
quantize_fp8,
|
|
24
|
-
)
|
|
25
|
-
from ..model import Transformer, TransformerBlock
|
|
26
|
-
from ..multimodal.model import CrossAttentionTransformer
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def swiglu_wrapper(
    self,
    x: Tensor,
):
    """Replacement ``forward`` for an FP8-converted feed-forward module.

    It is attached to a feed-forward module via ``swiglu_wrapper.__get__(block.feed_forward)``,
    so ``self`` is that module. It runs ``ffn_swiglu`` on the ``w1``/``w3``/``w2`` weights
    and passes the result through fairscale's ``reduce_from_model_parallel_region``
    (presumably a reduction across model-parallel ranks — confirm against fairscale).
    """
    w1, w3, w2 = self.w1.weight, self.w3.weight, self.w2.weight
    local_out = ffn_swiglu(x, w1, w3, w2)
    return reduce_from_model_parallel_region(local_out)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def convert_to_quantized_model(
    model: Transformer | CrossAttentionTransformer,
    checkpoint_dir: str,
    quantization_mode: str | None = None,
    fp8_activation_scale_ub: float | None = 1200.0,
    device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
    """Dispatch ``model`` to the converter for ``quantization_mode``.

    Args:
        model: Model to convert.
        checkpoint_dir: Checkpoint directory forwarded to the chosen converter.
        quantization_mode: ``QuantizationMode.fp8_mixed`` or ``QuantizationMode.int4_mixed``.
        fp8_activation_scale_ub: Only used by the FP8 path.
        device: Target device forwarded to the chosen converter.

    Returns:
        The converted model.

    Raises:
        ValueError: If ``quantization_mode`` is any other value (including ``None``).
    """
    if quantization_mode == QuantizationMode.fp8_mixed:
        return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)

    if quantization_mode == QuantizationMode.int4_mixed:
        return convert_to_int4_quantized_model(model, checkpoint_dir, device)

    raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def convert_to_fp8_quantized_model(
    model: Transformer,
    checkpoint_dir: str,
    fp8_activation_scale_ub: float | None = 1200.0,
    device: torch.device | None = None,
) -> Transformer:
    """Convert the feed-forward weights of ``model`` to FP8, in place.

    Every ``TransformerBlock`` except the first (``layer_id == 0``) and the last
    (``layer_id == model.n_layers - 1``) is modified as follows:

    * its ``feed_forward.forward`` is replaced by ``swiglu_wrapper``;
    * its ``w1``/``w3``/``w2`` weights are converted to FP8.

    If ``fp8_scales_<rank>.pt`` exists in ``checkpoint_dir``, the stored scales are
    applied with ``load_fp8``. Otherwise the current weights are quantized with
    ``quantize_fp8``. Afterwards, every parameter that is not an ``Fp8ScaledWeights``
    is moved to ``device``.

    Args:
        model: Model to convert; it is modified in place.
        checkpoint_dir: Directory searched for ``fp8_scales_<rank>.pt``.
        fp8_activation_scale_ub: Upper bound forwarded to ``load_fp8`` / ``quantize_fp8``.
        device: ``output_device`` for freshly quantized weights, and the target device
            for the remaining parameters.

    Returns:
        The same ``model`` object.
    """
    # Looked up once instead of once per weight; presumably constant for the
    # duration of this call.
    rank = get_model_parallel_rank()

    fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
    if os.path.isfile(fp8_scales_path):
        print("Loading fp8 scales...")
        fp8_scales = torch.load(fp8_scales_path, weights_only=True)

        def convert_weight(layer_id: int, key: str, weight: Tensor):
            return load_fp8(
                weight,
                fp8_scales[f"{layer_id}_feed_forward.{key}_{rank}"],
                fp8_activation_scale_ub,
            )

    else:
        print("Quantizing fp8 weights from bf16...")

        def convert_weight(layer_id: int, key: str, weight: Tensor):
            return quantize_fp8(
                weight,
                fp8_activation_scale_ub,
                output_device=device,
            )

    for block in _fp8_eligible_blocks(model):
        block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)  # type: ignore
        for key in ("w1", "w3", "w2"):
            param = getattr(block.feed_forward, key)
            param.weight = convert_weight(block.layer_id, key, param.weight)

    for _, parameter in model.named_parameters():
        if not isinstance(parameter, Fp8ScaledWeights):
            parameter.data = parameter.to(device=device)
    return model


def _fp8_eligible_blocks(model: Transformer):
    """Yield each ``TransformerBlock`` of ``model`` except the first and last layers."""
    last_layer_id = model.n_layers - 1
    for _, block in model.named_modules():
        if isinstance(block, TransformerBlock) and block.layer_id not in (0, last_layer_id):
            yield block
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
    """
    Int8DynActInt4WeightLinear with an optional LoRA adaptor.

    When ``lora_rank`` is given, ``forward`` returns
    ``base(x) + adaptor(x) * lora_scale``. Otherwise it returns ``base(x)``.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
        device: Device to use.
        group_size: Group size for quantization.
        precision: Precision of quantization.
        scales_precision: Precision of scales.
        lora_rank: Rank of LoRA adaptor.
        lora_scale: Scale of LoRA adaptor.

    Raises:
        ValueError: If ``lora_rank`` is given without ``lora_scale``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=False,
        device=None,
        # quantization parameters
        group_size: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
        # LoRA parameters
        lora_rank: int | None = None,
        lora_scale: float | None = None,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            groupsize=group_size,
            precision=precision,
            scales_precision=scales_precision,
        )
        self.lora_scale: float | None = None
        self.adaptor: nn.Sequential | None = None
        if lora_rank is not None:
            # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
            if lora_scale is None:
                raise ValueError("Please specify lora scale for LoRA.")
            # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
            self.adaptor = nn.Sequential()
            self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
            self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
            self.lora_scale = lora_scale
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        """A hook to load the quantized weights from the state dict.

        If ``<prefix>zeros`` is absent, it is filled with zeros shaped like ``<prefix>scales``.

        Raises:
            KeyError: If both ``<prefix>zeros`` and ``<prefix>scales`` are missing.
        """
        zeros_key = prefix + "zeros"
        if zeros_key in state_dict:
            return
        # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
        scales_key = prefix + "scales"
        if scales_key not in state_dict:
            raise KeyError(f"'{scales_key}' is required in the state dict when '{zeros_key}' is absent")
        state_dict[zeros_key] = torch.zeros_like(state_dict[scales_key])

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear layer, plus the scaled LoRA adaptor if present."""
        module_out = super().forward(input_)
        if self.adaptor is not None:
            adaptor_out = self.adaptor(input_) * self.lora_scale
            return module_out + adaptor_out
        return module_out
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
class Int8WeightEmbedding(torch.nn.Embedding):
    """Embedding layer that loads int8 weights stored together with a ``scales`` tensor.

    When a state dict is loaded, a pre-hook replaces the ``weight`` and ``scales``
    entries with a single dequantized ``weight`` equal to ``weight * scales``.

    Args:
        num_embeddings: Number of embeddings.
        embedding_dim: Embedding dimension.
        padding_idx: Padding index.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        device=None,
    ) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        """Fold ``<prefix>scales`` into ``<prefix>weight`` (dequantize) in ``state_dict``."""
        weight_key = prefix + "weight"
        scales_key = prefix + "scales"
        quantized = state_dict.pop(weight_key)
        state_dict[weight_key] = quantized * state_dict.pop(scales_key)
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
class Int8WeightLinear(torch.nn.Linear):
    """Linear layer that loads int8 weights stored together with a ``scales`` tensor.

    When a state dict is loaded, a pre-hook replaces the ``weight`` and ``scales``
    entries with a single dequantized ``weight`` equal to ``weight * scales``.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
        super().__init__(in_features, out_features, bias, device=device)
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        """Fold ``<prefix>scales`` into ``<prefix>weight`` (dequantize) in ``state_dict``."""
        weight_key, scales_key = prefix + "weight", prefix + "scales"
        int8_weights = state_dict.pop(weight_key)
        weight_scales = state_dict.pop(scales_key)
        state_dict[weight_key] = int8_weights * weight_scales
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
def _prepare_model_int4_weight_int8_dynamic_activation(
    model: torch.nn.Module,
    group_size: int,
    lora_rank: int | None,
    lora_scale: float | None,
):
    """Prepare the model for int4 weight and int8 dynamic activation quantization.

    Walks the module tree recursively and replaces, in place:

    * the child named ``output`` with an ``Int8WeightLinear`` (int8 weights);
    * the child named ``tok_embeddings`` with an ``Int8WeightEmbedding`` (int8 weights);
    * any other ``ColumnParallelLinear`` / ``RowParallelLinear`` / ``nn.Linear`` with an
      ``Int8DynActInt4WeightLinearLoRA`` (int4 weights, int8 dynamic activations).

    Replacement modules are built from the original's shapes only; no weights are
    copied. Presumably the quantized weights are loaded from a checkpoint afterwards.

    Returns:
        The same ``model`` object.
    """
    # Modules are created on the default device; the caller in this file moves
    # the whole model afterwards via ``model.to(device=...)``.
    device = None
    for module_name, module in model.named_children():
        if module_name == "output":
            quantized_module = Int8WeightLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                # ``nn.Linear`` expects a bool flag here; passing the bias tensor itself
                # fails for multi-element tensors ("Boolean value ... is ambiguous").
                bias=module.bias is not None,
                device=device,
            )
            setattr(model, module_name, quantized_module)
        elif module_name == "tok_embeddings":
            quantized_module = Int8WeightEmbedding(
                num_embeddings=module.num_embeddings,
                embedding_dim=module.embedding_dim,
                padding_idx=module.padding_idx,
                device=device,
            )
            setattr(model, module_name, quantized_module)
        elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
            # NOTE(review): bias is always dropped for these layers — confirm the
            # quantized checkpoints never carry a bias for them.
            quantized_module = Int8DynActInt4WeightLinearLoRA(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=False,
                group_size=group_size,
                lora_rank=lora_rank,
                lora_scale=lora_scale,
                device=device,
            )
            setattr(model, module_name, quantized_module)
        else:
            _prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)

    return model
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
def convert_to_int4_quantized_model(
    model: Transformer | CrossAttentionTransformer,
    checkpoint_dir: str,
    device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
    """Convert the model to int4 quantized model.

    Reads ``model.params.quantization_args`` and ``model.params.lora_args``, swaps in
    the quantized layers, and moves the model to ``device``. ``checkpoint_dir`` is
    accepted for interface parity with the FP8 path but is not used here.

    Raises:
        ValueError: If ``quantization_args``, its ``scheme``, or its ``group_size`` is missing.
        NotImplementedError: If the scheme is not ``int4_weight_int8_dynamic_activation``.
    """
    model_args = model.params
    quantization_args = model_args.quantization_args
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if quantization_args is None:
        raise ValueError("Quantization args must be specified.")
    if quantization_args.scheme is None:
        raise ValueError("Quantization scheme must be specified in 'quantization_args'.")

    if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
        raise NotImplementedError(
            "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
        )

    group_size = quantization_args.group_size
    if group_size is None:
        raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")

    # Certain quantized models (e.g., SpinQuant) may not have LoRA.
    lora_args = model_args.lora_args
    lora_rank = None if lora_args is None else lora_args.rank
    lora_scale = None if lora_args is None else lora_args.scale

    _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
    return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
-
# All rights reserved.
|
|
9
|
-
#
|
|
10
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
11
|
-
# top-level folder for each specific model found within the models/ directory at
|
|
12
|
-
# the top-level of this source tree.
|