llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (155)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
llama_stack/models/llama/llama3/multimodal/utils.py
@@ -1,26 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- import collections
-
- import torch
-
-
- def get_negative_inf_value(dtype):
-     return torch.finfo(dtype).min
-
-
- def to_2tuple(x):
-     if isinstance(x, collections.abc.Iterable):
-         return x
-     return (x, x)
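For quick reference, the two helpers deleted in this hunk are self-contained. A minimal standalone sketch of their behavior (the function bodies match the removed code; the asserts are illustrative additions, and the attention-mask use is the typical purpose, inferred from the multimodal encoder context):

    import collections.abc

    import torch


    def get_negative_inf_value(dtype: torch.dtype) -> float:
        # Most negative finite value representable by the dtype;
        # typically used as a fill value when masking attention scores.
        return torch.finfo(dtype).min


    def to_2tuple(x):
        # Pass iterables through unchanged; duplicate a scalar into a pair
        # (e.g., a single image size into (height, width)).
        if isinstance(x, collections.abc.Iterable):
            return x
        return (x, x)


    assert to_2tuple(224) == (224, 224)          # scalar -> pair
    assert to_2tuple((224, 336)) == (224, 336)   # iterable passes through
    assert get_negative_inf_value(torch.float16) == torch.finfo(torch.float16).min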
llama_stack/models/llama/llama3/pasta.jpeg: Binary file (removed)
llama_stack/models/llama/llama3/quantization/__init__.py
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
llama_stack/models/llama/llama3/quantization/loader.py
@@ -1,316 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # type: ignore
- import os
- from typing import Any, cast
-
- import torch
- from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
- from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
- from torch import Tensor, nn
- from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
-
- from ...datatypes import QuantizationMode
- from ...quantize_impls import (
-     Fp8ScaledWeights,
-     ffn_swiglu,
-     load_fp8,
-     quantize_fp8,
- )
- from ..model import Transformer, TransformerBlock
- from ..multimodal.model import CrossAttentionTransformer
-
-
- def swiglu_wrapper(
-     self,
-     x: Tensor,
- ):
-     out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
-     return reduce_from_model_parallel_region(out)
-
-
- def convert_to_quantized_model(
-     model: Transformer | CrossAttentionTransformer,
-     checkpoint_dir: str,
-     quantization_mode: str | None = None,
-     fp8_activation_scale_ub: float | None = 1200.0,
-     device: torch.device | None = None,
- ) -> Transformer | CrossAttentionTransformer:
-     if quantization_mode == QuantizationMode.fp8_mixed:
-         return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
-     elif quantization_mode == QuantizationMode.int4_mixed:
-         return convert_to_int4_quantized_model(model, checkpoint_dir, device)
-     else:
-         raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
-
-
- def convert_to_fp8_quantized_model(
-     model: Transformer,
-     checkpoint_dir: str,
-     fp8_activation_scale_ub: float | None = 1200.0,
-     device: torch.device | None = None,
- ) -> Transformer:
-     # Move weights to GPU with quantization
-     fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
-     if os.path.isfile(fp8_scales_path):
-         print("Loading fp8 scales...")
-         fp8_scales = torch.load(fp8_scales_path, weights_only=True)
-
-         for _, block in model.named_modules():
-             if isinstance(block, TransformerBlock):
-                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-                     continue
-
-                 block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
-                 for key in ("w1", "w3", "w2"):
-                     param = getattr(block.feed_forward, key)
-                     param.weight = load_fp8(
-                         param.weight,
-                         fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
-                         fp8_activation_scale_ub,
-                     )
-     else:
-         print("Quantizing fp8 weights from bf16...")
-         for _, block in model.named_modules():
-             if isinstance(block, TransformerBlock):
-                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-                     continue
-                 block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)  # type: ignore
-                 for key in ("w1", "w3", "w2"):
-                     param = getattr(block.feed_forward, key)
-                     param.weight = quantize_fp8(
-                         param.weight,
-                         fp8_activation_scale_ub,
-                         output_device=device,
-                     )
-
-     for _, parameter in model.named_parameters():
-         if not isinstance(parameter, Fp8ScaledWeights):
-             parameter.data = parameter.to(device=device)
-     return model
-
-
- class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
-     """
-     Int8DynActInt4WeightLinear with LoRA adaptor.
-
-     Args:
-         in_features: Number of input features.
-         out_features: Number of output features.
-         bias: Whether to use bias.
-         device: Device to use.
-         group_size: Group size for quantization.
-         precision: Precision of quantization.
-         scales_precision: Precision of scales.
-         lora_rank: Rank of LoRA adaptor.
-         lora_scale: Scale of LoRA adaptor.
-     """
-
-     def __init__(
-         self,
-         in_features: int,
-         out_features: int,
-         bias=False,
-         device=None,
-         # quantization parameters
-         group_size: int = 256,
-         precision: torch.dtype = torch.float32,
-         scales_precision: torch.dtype = torch.float32,
-         # LoRA parameters
-         lora_rank: int | None = None,
-         lora_scale: float | None = None,
-     ) -> None:
-         super().__init__(
-             in_features,
-             out_features,
-             bias=bias,
-             device=device,
-             groupsize=group_size,
-             precision=precision,
-             scales_precision=scales_precision,
-         )
-         self.lora_scale: float | None = None
-         self.adaptor: nn.Sequential | None = None
-         if lora_rank is not None:
-             assert lora_scale is not None, "Please specify lora scale for LoRA."
-             # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
-             self.adaptor = nn.Sequential()
-             self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
-             self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
-             self.lora_scale = lora_scale
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized weights from the state dict."""
-         if prefix + "zeros" not in state_dict:
-             # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
-             assert prefix + "scales" in state_dict
-             state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
-
-     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-         module_out = super().forward(input_)
-         if self.adaptor is not None:
-             adaptor_out = self.adaptor(input_) * self.lora_scale
-             return module_out + adaptor_out
-         return module_out
-
-
- class Int8WeightEmbedding(torch.nn.Embedding):
-     """An embedding layer to load int8 weights.
-
-     Args:
-         num_embeddings: Number of embeddings.
-         embedding_dim: Embedding dimension.
-         padding_idx: Padding index.
-     """
-
-     def __init__(
-         self,
-         num_embeddings: int,
-         embedding_dim: int,
-         padding_idx: int,
-         device=None,
-     ) -> None:
-         super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized embedding weight and scales from the state dict."""
-         weights = state_dict.pop(prefix + "weight")
-         scales = state_dict.pop(prefix + "scales")
-         state_dict[prefix + "weight"] = weights * scales
-
-
- class Int8WeightLinear(torch.nn.Linear):
-     """A linear layer to load int8 weights.
-
-     Args:
-         in_features: Number of input features.
-         out_features: Number of output features.
-         bias: Whether to use bias.
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
-         super().__init__(in_features, out_features, bias, device=device)
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool,
-         missing_keys: list[str],
-         unexpected_keys: list[str],
-         error_msgs: list[str],
-     ) -> None:
-         """A hook to load the quantized linear weight and scales from the state dict."""
-         weights = state_dict.pop(prefix + "weight")
-         scales = state_dict.pop(prefix + "scales")
-         state_dict[prefix + "weight"] = weights * scales
-
-
- def _prepare_model_int4_weight_int8_dynamic_activation(
-     model: torch.nn.Module,
-     group_size: int,
-     lora_rank: int | None,
-     lora_scale: float | None,
- ):
-     """Prepare the model for int4 weight and int8 dynamic activation quantization.
-
-     Note that the weights of embedding and output layers are quantized to int8.
-     """
-     device = None
-     for module_name, module in model.named_children():
-         if module_name == "output":
-             quantized_module = Int8WeightLinear(
-                 in_features=module.in_features,
-                 out_features=module.out_features,
-                 bias=module.bias,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         elif module_name == "tok_embeddings":
-             quantized_module = Int8WeightEmbedding(
-                 num_embeddings=module.num_embeddings,
-                 embedding_dim=module.embedding_dim,
-                 padding_idx=module.padding_idx,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
-             quantized_module = Int8DynActInt4WeightLinearLoRA(
-                 in_features=module.in_features,
-                 out_features=module.out_features,
-                 bias=False,
-                 group_size=group_size,
-                 lora_rank=lora_rank,
-                 lora_scale=lora_scale,
-                 device=device,
-             )
-             del module
-             setattr(model, module_name, quantized_module)
-         else:
-             _prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
-
-     return model
-
-
- def convert_to_int4_quantized_model(
-     model: Transformer | CrossAttentionTransformer,
-     checkpoint_dir: str,
-     device: torch.device | None = None,
- ) -> Transformer | CrossAttentionTransformer:
-     """Convert the model to int4 quantized model."""
-     model_args = model.params
-     assert model_args.quantization_args is not None, "Quantization args must be specified."
-     quantization_args = model_args.quantization_args
-     if quantization_args.scheme is None:
-         raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
-
-     if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
-         raise NotImplementedError(
-             "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
-         )
-
-     group_size = model_args.quantization_args.group_size
-     if group_size is None:
-         raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
-
-     if model_args.lora_args is None:
-         # Certain quantized models (e.g., SpinQuant) may not have LoRA.
-         lora_rank = None
-         lora_scale = None
-     else:
-         lora_rank = model_args.lora_args.rank
-         lora_scale = model_args.lora_args.scale
-
-     _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
-     return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
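The core idea in the removed Int8DynActInt4WeightLinearLoRA.forward is the standard LoRA residual: the (quantized) base layer's output plus a scaled low-rank correction, adaptor(x) * lora_scale. A minimal sketch of that composition, with a plain nn.Linear standing in for the removed torchao Int8DynActInt4WeightLinear base class (the class name, dimensions, and asserts below are illustrative, not part of llama-stack):

    import torch
    from torch import nn


    class LinearWithLoRA(nn.Module):
        """Illustrative stand-in: dense base layer plus a LoRA adaptor."""

        def __init__(self, in_features: int, out_features: int, lora_rank: int, lora_scale: float):
            super().__init__()
            self.base = nn.Linear(in_features, out_features, bias=False)
            # Low-rank adaptation (https://arxiv.org/abs/2106.09685):
            # "A" projects down to rank r, "B" projects back up,
            # mirroring the adaptor built in the removed __init__.
            self.adaptor = nn.Sequential(
                nn.Linear(in_features, lora_rank, bias=False),
                nn.Linear(lora_rank, out_features, bias=False),
            )
            self.lora_scale = lora_scale

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Same composition as the removed forward():
            # module_out + adaptor(input_) * lora_scale
            return self.base(x) + self.adaptor(x) * self.lora_scale


    layer = LinearWithLoRA(512, 512, lora_rank=16, lora_scale=2.0)
    y = layer(torch.randn(2, 512))
    assert y.shape == (2, 512)

In the removed code the same residual sits on top of an int4-weight, int8-dynamic-activation base layer, so only the tiny A/B matrices stay in full precision.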
llama_stack/models/llama/llama3_1/__init__.py
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.