ai_edge_torch_nightly-0.3.0.dev20250114-py3-none-any.whl
- ai_edge_torch/__init__.py +32 -0
- ai_edge_torch/_config.py +69 -0
- ai_edge_torch/_convert/__init__.py +14 -0
- ai_edge_torch/_convert/conversion.py +153 -0
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/_convert/converter.py +270 -0
- ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
- ai_edge_torch/_convert/signature.py +66 -0
- ai_edge_torch/_convert/test/__init__.py +14 -0
- ai_edge_torch/_convert/test/test_convert.py +558 -0
- ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
- ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/_convert/to_channel_last_io.py +92 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +496 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +140 -0
- ai_edge_torch/debug/test/test_search_model.py +51 -0
- ai_edge_torch/debug/utils.py +59 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/fx_pass_base.py +110 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
- ai_edge_torch/generative/examples/llama/llama.py +196 -0
- ai_edge_torch/generative/examples/llama/verify.py +88 -0
- ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
- ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
- ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
- ai_edge_torch/generative/examples/openelm/verify.py +71 -0
- ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
- ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
- ai_edge_torch/generative/examples/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/phi/phi2.py +107 -0
- ai_edge_torch/generative/examples/phi/phi3.py +219 -0
- ai_edge_torch/generative/examples/phi/verify.py +64 -0
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
- ai_edge_torch/generative/examples/qwen/verify.py +88 -0
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
- ai_edge_torch/generative/examples/smollm/verify.py +86 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
- ai_edge_torch/generative/examples/t5/t5.py +655 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
- ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
- ai_edge_torch/generative/fx_passes/__init__.py +30 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +399 -0
- ai_edge_torch/generative/layers/attention_utils.py +210 -0
- ai_edge_torch/generative/layers/builder.py +160 -0
- ai_edge_torch/generative/layers/feed_forward.py +120 -0
- ai_edge_torch/generative/layers/kv_cache.py +204 -0
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +238 -0
- ai_edge_torch/generative/layers/normalization.py +222 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
- ai_edge_torch/generative/layers/unet/builder.py +50 -0
- ai_edge_torch/generative/layers/unet/model_config.py +282 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +47 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_custom_dus.py +107 -0
- ai_edge_torch/generative/test/test_kv_cache.py +120 -0
- ai_edge_torch/generative/test/test_loader.py +83 -0
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion.py +191 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
- ai_edge_torch/generative/test/test_quantize.py +183 -0
- ai_edge_torch/generative/test/utils.py +82 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/converter.py +215 -0
- ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
- ai_edge_torch/generative/utilities/loader.py +398 -0
- ai_edge_torch/generative/utilities/model_builder.py +180 -0
- ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
- ai_edge_torch/generative/utilities/t5_loader.py +512 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +335 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
- ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
- ai_edge_torch/lowertools/__init__.py +18 -0
- ai_edge_torch/lowertools/_shim.py +86 -0
- ai_edge_torch/lowertools/common_utils.py +142 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
- ai_edge_torch/lowertools/test_utils.py +62 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
- ai_edge_torch/lowertools/translate_recipe.py +163 -0
- ai_edge_torch/model.py +177 -0
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +88 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +403 -0
- ai_edge_torch/odml_torch/export_utils.py +157 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
- ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +156 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
- ai_edge_torch/version.py +16 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/layers/lora.py
@@ -0,0 +1,557 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""LoRA weights for generative models.

The current implementation supports attention-only LoRA. Additionally, we
expect LoRA weights for all projections within the attention module (i.e.,
Q, K, V, and O).
"""

import dataclasses
from typing import Any, Callable, List, Optional, Tuple

from ai_edge_torch.generative.layers import model_config
import flatbuffers
import numpy as np
import safetensors
import torch
import torch.utils._pytree as pytree

from tensorflow.lite.python import schema_py_generated as schema_fb  # pylint: disable=g-direct-tensorflow-import

_TFLITE_SCHEMA_VERSION = 3
_TFLITE_FILE_IDENTIFIER = b"TFL3"

@dataclasses.dataclass
class LoRAWeight:
  """LoRA weight per projection. The weights are pre-transposed."""

  a_prime: torch.Tensor
  b_prime: torch.Tensor

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, LoRAWeight):
      return False
    if self.a_prime.shape != other.a_prime.shape:
      return False
    if self.b_prime.shape != other.b_prime.shape:
      return False
    return torch.allclose(
        self.a_prime, other.a_prime, rtol=rtol, atol=atol
    ) and torch.allclose(self.b_prime, other.b_prime, rtol=rtol, atol=atol)

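
# A minimal sketch of the "pre-transposed" convention noted in the docstring
# above, assuming arbitrary example shapes: `a_prime` is stored as
# (in_features, rank) and `b_prime` as (rank, out_features), so the low-rank
# update is applied with plain matmuls and no transposes at inference time.
def _example_pretransposed_shapes() -> torch.Tensor:
  w = LoRAWeight(
      a_prime=torch.zeros(2048, 8),  # (embedding_dim, rank)
      b_prime=torch.zeros(8, 2048),  # (rank, out_dim)
  )
  x = torch.randn(1, 16, 2048)  # (batch, seq_len, embedding_dim)
  return x @ w.a_prime @ w.b_prime  # (batch, seq_len, out_dim)
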
@dataclasses.dataclass
class AttentionLoRA:
  """LoRA weights for the attention module."""

  query: LoRAWeight
  key: LoRAWeight
  value: LoRAWeight
  output: LoRAWeight

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, AttentionLoRA):
      return False
    return (
        self.query.__eq__(other.query, rtol=rtol, atol=atol)
        and self.key.__eq__(other.key, rtol=rtol, atol=atol)
        and self.value.__eq__(other.value, rtol=rtol, atol=atol)
        and self.output.__eq__(other.output, rtol=rtol, atol=atol)
    )

@dataclasses.dataclass
class LoRAEntry:
  """LoRA weights for a single layer."""

  attention: AttentionLoRA

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, LoRAEntry):
      return False
    return self.attention.__eq__(other.attention, rtol=rtol, atol=atol)

@dataclasses.dataclass
class LoRATensorNames:
  """Tensor names for LoRA weights."""

  attn_query_w_a: str
  attn_query_w_b: str

  attn_key_w_a: str
  attn_key_w_b: str

  attn_value_w_a: str
  attn_value_w_b: str

  attn_output_w_a: str
  attn_output_w_b: str

@dataclasses.dataclass
class LoRA:
  """LoRA weights for all modules."""

  adapters: Tuple[LoRAEntry, ...]

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, LoRA):
      return False
    if len(self.adapters) != len(other.adapters):
      return False
    return all(
        adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
        for adapter, other_adapter in zip(self.adapters, other.adapters)
    )

  def get_rank(self) -> int:
    """Returns the rank of the LoRA weights."""
    return self.adapters[0].attention.query.a_prime.shape[1]

  @classmethod
  def from_safetensors(
      cls,
      path: str,
      scale: float,
      config: model_config.ModelConfig,
      lora_tensor_names: LoRATensorNames,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a Hugging Face model.

    Args:
      path: Path to the model.
      scale: Scale factor for the LoRA weights (applied only to one of the
        projections). The scaling factor depends on the training configuration.
        The common values are either `lora_alpha / rank` or `lora_alpha /
        sqrt(rank)`.
      config: Model configuration.
      lora_tensor_names: Tensor names for the LoRA weights.
      dtype: Data type of the LoRA weights. Currently only float32 is
        supported.

    Returns:
      LoRA weights for all modules.
    """
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:
      adapters = []
      for i in range(config.num_layers):
        attention_lora = AttentionLoRA(
            query=LoRAWeight(
                a_prime=f.get_tensor(lora_tensor_names.attn_query_w_a.format(i))
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(lora_tensor_names.attn_query_w_b.format(i))
                .to(dtype)
                .T,
            ),
            key=LoRAWeight(
                a_prime=f.get_tensor(lora_tensor_names.attn_key_w_a.format(i))
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(lora_tensor_names.attn_key_w_b.format(i))
                .to(dtype)
                .T,
            ),
            value=LoRAWeight(
                a_prime=f.get_tensor(lora_tensor_names.attn_value_w_a.format(i))
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(lora_tensor_names.attn_value_w_b.format(i))
                .to(dtype)
                .T,
            ),
            output=LoRAWeight(
                a_prime=f.get_tensor(
                    lora_tensor_names.attn_output_w_a.format(i)
                )
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(
                    lora_tensor_names.attn_output_w_b.format(i)
                )
                .to(dtype)
                .T,
            ),
        )
        adapters.append(LoRAEntry(attention=attention_lora))
    return cls(adapters=adapters)

  @classmethod
  def from_flatbuffers(
      cls,
      flatbuffer_model: bytearray,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from FlatBuffers.

    Args:
      flatbuffer_model: FlatBuffers model.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights for all modules.
    """
    model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
    model = schema_fb.ModelT.InitFromObj(model)

    flat_names = []
    tensors = []
    for tensor in model.subgraphs[0].tensors:
      name = tensor.name.decode("utf-8")
      assert name.startswith("lora_")
      flat_names.append(name.split("lora_")[-1])
      buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
      arr = np.frombuffer(buffer_bytes, dtype=np.float32).reshape(tensor.shape)
      torch_tensor = torch.from_numpy(arr).to(dtype)
      tensors.append(torch_tensor)

    return _unflatten_lora(tensors, (flat_names, []))

  @classmethod
  def zeros(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with zeros.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights. Currently only float32 is
        supported.

    Returns:
      LoRA weights with zeros.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def random(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with random values.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights with random values.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.randint(
            low=0, high=128, size=shape, dtype=dtype
        ),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def _from_tensor_generator(
      cls,
      tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a tensor generator."""
    adapters = []

    for i in range(config.num_layers):
      block_config = config.block_config(i)
      q_per_kv = (
          block_config.attn_config.num_heads
          // block_config.attn_config.num_query_groups
      )
      q_out_dim = q_per_kv * block_config.attn_config.head_dim
      k_out_dim = v_out_dim = block_config.attn_config.head_dim
      attention_lora = AttentionLoRA(
          query=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, q_out_dim), dtype),
          ),
          key=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, k_out_dim), dtype),
          ),
          value=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, v_out_dim), dtype),
          ),
          output=LoRAWeight(
              a_prime=tensor_generator(
                  (
                      block_config.attn_config.num_heads
                      * block_config.attn_config.head_dim,
                      rank,
                  ),
                  dtype,
              ),
              b_prime=tensor_generator((rank, config.embedding_dim), dtype),
          ),
      )
      adapters.append(LoRAEntry(attention=attention_lora))
    return cls(adapters=adapters)

  def to_tflite(self) -> bytearray:
    """Converts LoRA to FlatBuffers."""
    return _lora_to_flatbuffers(self)

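
# A minimal usage sketch for the constructors above, assuming a hypothetical
# PEFT-style checkpoint "adapter_model.safetensors" and placeholder name
# templates; `{}` in each template is filled with the layer index via
# `.format(i)` in `from_safetensors`. The templates, `lora_alpha=16`, and
# `rank=8` are example assumptions, not part of this module; `lora_alpha /
# rank` is one common choice of scale (see the `from_safetensors` docstring).
def _example_build_lora(config: model_config.ModelConfig) -> LoRA:
  names = LoRATensorNames(
      attn_query_w_a="layers.{}.self_attn.q_proj.lora_A.weight",
      attn_query_w_b="layers.{}.self_attn.q_proj.lora_B.weight",
      attn_key_w_a="layers.{}.self_attn.k_proj.lora_A.weight",
      attn_key_w_b="layers.{}.self_attn.k_proj.lora_B.weight",
      attn_value_w_a="layers.{}.self_attn.v_proj.lora_A.weight",
      attn_value_w_b="layers.{}.self_attn.v_proj.lora_B.weight",
      attn_output_w_a="layers.{}.self_attn.o_proj.lora_A.weight",
      attn_output_w_b="layers.{}.self_attn.o_proj.lora_B.weight",
  )
  lora = LoRA.from_safetensors(
      path="adapter_model.safetensors",
      scale=16.0 / 8,  # lora_alpha / rank, example values only.
      config=config,
      lora_tensor_names=names,
  )
  assert lora.get_rank() == 8  # a_prime is (embedding_dim, rank).
  return lora
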
def apply_lora(
    x: torch.Tensor,
    lora_weight: LoRAWeight,
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Applies LoRA weights to a tensor.

  Args:
    x: Input tensor.
    lora_weight: LoRA weight.
    shape: Output shape. If None, the output shape is the same as the input
      shape.

  Returns:
    Output tensor.
  """
  output = torch.matmul(
      torch.matmul(x, lora_weight.a_prime), lora_weight.b_prime
  )
  if shape is not None:
    output = output.reshape(shape)
  return output

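
# A worked sketch of `apply_lora`, assuming example shapes: with x of shape
# (batch, seq_len, embedding_dim) and pre-transposed weights, the two matmuls
# produce (batch, seq_len, out_dim); `shape` then reshapes the result, e.g.
# to split the head dimension back out for multi-head attention.
def _example_apply_lora() -> torch.Tensor:
  x = torch.randn(2, 16, 2048)
  w = LoRAWeight(a_prime=torch.randn(2048, 8), b_prime=torch.randn(8, 2048))
  # (2, 16, 2048) @ (2048, 8) -> (2, 16, 8); @ (8, 2048) -> (2, 16, 2048).
  return apply_lora(x, w, shape=(2, 16, 8, 256))  # e.g. 8 heads x 256 dims.
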
def _flatten_attention_lora(
    lora: AttentionLoRA, block_index: int
) -> Tuple[List[torch.Tensor], List[str]]:
  """Flattens LoRA weights for the attention module."""
  flattened = []
  flat_names = []
  flattened.append(lora.query.a_prime)
  flat_names.append(f"atten_q_a_prime_weight_{block_index}")
  flattened.append(lora.query.b_prime)
  flat_names.append(f"atten_q_b_prime_weight_{block_index}")
  flattened.append(lora.key.a_prime)
  flat_names.append(f"atten_k_a_prime_weight_{block_index}")
  flattened.append(lora.key.b_prime)
  flat_names.append(f"atten_k_b_prime_weight_{block_index}")
  flattened.append(lora.value.a_prime)
  flat_names.append(f"atten_v_a_prime_weight_{block_index}")
  flattened.append(lora.value.b_prime)
  flat_names.append(f"atten_v_b_prime_weight_{block_index}")
  flattened.append(lora.output.a_prime)
  flat_names.append(f"atten_o_a_prime_weight_{block_index}")
  flattened.append(lora.output.b_prime)
  flat_names.append(f"atten_o_b_prime_weight_{block_index}")
  return flattened, flat_names

def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
  """Flattens LoRA weights."""
  flattened = []
  flat_names = []
  none_names = []
  for i, entry in enumerate(lora.adapters):
    attn_flattened, attn_flat_names = _flatten_attention_lora(
        lora=entry.attention, block_index=i
    )
    flattened.extend(attn_flattened)
    flat_names.extend(attn_flat_names)
  return flattened, [flat_names, none_names]

def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
  """Flattens LoRA weights with keys."""
  flattened, (flat_names, _) = _flatten_lora(lora)
  return [
      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
  ], flat_names

def _unflatten_lora(
    values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
) -> LoRA:
  """Unflattens a LoRA object."""
  flat_names, _ = context
  names_weights = list(zip(flat_names, values))
  adapters = {}
  while names_weights:
    name, weight = names_weights.pop(0)
    block_idx = int(name.split("_")[-1])
    if block_idx not in adapters:
      adapters[block_idx] = LoRAEntry(
          attention=AttentionLoRA(
              query=LoRAWeight(a_prime=None, b_prime=None),
              key=LoRAWeight(a_prime=None, b_prime=None),
              value=LoRAWeight(a_prime=None, b_prime=None),
              output=LoRAWeight(a_prime=None, b_prime=None),
          )
      )

    if name.startswith("atten_"):
      if "q_a_prime" in name:
        adapters[block_idx].attention.query.a_prime = weight
      elif "q_b_prime" in name:
        adapters[block_idx].attention.query.b_prime = weight
      elif "k_a_prime" in name:
        adapters[block_idx].attention.key.a_prime = weight
      elif "k_b_prime" in name:
        adapters[block_idx].attention.key.b_prime = weight
      elif "v_a_prime" in name:
        adapters[block_idx].attention.value.a_prime = weight
      elif "v_b_prime" in name:
        adapters[block_idx].attention.value.b_prime = weight
      elif "o_a_prime" in name:
        adapters[block_idx].attention.output.a_prime = weight
      elif "o_b_prime" in name:
        adapters[block_idx].attention.output.b_prime = weight
      else:
        raise ValueError(f"Unsupported name: {name}")
    else:
      raise ValueError(f"Unsupported name: {name}")

  return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))

pytree.register_pytree_node(
    LoRA,
    _flatten_lora,
    _unflatten_lora,
    flatten_with_keys_fn=_flatten_lora_with_keys,
    serialized_type_name="",
)

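
# A minimal sketch of what the pytree registration above enables: `LoRA`
# instances can be flattened into a plain list of tensors (e.g. so that
# export machinery can treat LoRA weights like ordinary tensor inputs) and
# reconstructed losslessly from the leaves and the tree spec.
def _example_pytree_roundtrip(lora: LoRA) -> None:
  leaves, spec = pytree.tree_flatten(lora)
  restored = pytree.tree_unflatten(leaves, spec)
  assert restored == lora  # Uses the tolerance-aware __eq__ defined above.
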
def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int:
  """Adds a buffer to the FlatBuffers."""
  if data is not None:
    assert data.dtype == np.float32
    schema_fb.BufferStartDataVector(builder, data.size * data.itemsize)
    for value in reversed(data.flatten().tolist()):
      builder.PrependFloat32(value)
    data_offset = builder.EndVector()
  else:
    schema_fb.BufferStartDataVector(builder, 0)
    data_offset = builder.EndVector()

  schema_fb.BufferStart(builder)
  schema_fb.BufferAddData(builder, data_offset)
  buffer_offset = schema_fb.BufferEnd(builder)
  return buffer_offset

def _add_tensor(
    builder: flatbuffers.Builder,
    name: str,
    shape: Tuple[int, ...],
    buffer_idx: int,
) -> int:
  """Adds a tensor to the FlatBuffers."""
  name_offset = builder.CreateString(name)
  schema_fb.TensorStartShapeVector(builder, len(shape))
  for dim in reversed(shape):
    builder.PrependInt32(dim)
  shape_offset = builder.EndVector()
  schema_fb.TensorStart(builder)
  schema_fb.TensorAddName(builder, name_offset)
  schema_fb.TensorAddShape(builder, shape_offset)
  schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
  schema_fb.TensorAddBuffer(builder, buffer_idx)
  tensor_offset = schema_fb.TensorEnd(builder)
  return tensor_offset

def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
  """Converts LoRA to FlatBuffers."""
  tensors, (names, _) = _flatten_lora(lora)
  # Need to manually add the "lora_" prefix to the names here. The export will
  # add the prefix automatically.
  names = [f"lora_{name}" for name in names]
  builder = flatbuffers.Builder(4096)

  # Convention to add an empty buffer in the beginning.
  buffer_offsets = [_add_buffer(builder, None)]
  for tensor in tensors:
    buffer_offsets.append(
        _add_buffer(builder, tensor.detach().type(torch.float32).numpy())
    )

  schema_fb.ModelStartBuffersVector(builder, len(buffer_offsets))
  for buffer_offset in reversed(buffer_offsets):
    builder.PrependUOffsetTRelative(buffer_offset)
  buffers_offset = builder.EndVector()

  tensor_offsets = []
  for i, (name, tensor) in enumerate(zip(names, tensors)):
    # Note that the zeroth buffer is empty and reserved for the convention.
    tensor_offsets.append(_add_tensor(builder, name, tensor.shape, i + 1))

  schema_fb.SubGraphStartTensorsVector(builder, len(tensor_offsets))
  for tensor_offset in reversed(tensor_offsets):
    builder.PrependUOffsetTRelative(tensor_offset)
  tensors_offset = builder.EndVector()

  string_offset = builder.CreateString("lora_params")
  schema_fb.SubGraphStart(builder)
  schema_fb.SubGraphAddName(builder, string_offset)
  schema_fb.SubGraphAddTensors(builder, tensors_offset)
  subgraph_offset = schema_fb.SubGraphEnd(builder)

  schema_fb.ModelStartSubgraphsVector(builder, 1)
  builder.PrependUOffsetTRelative(subgraph_offset)
  subgraphs_offset = builder.EndVector()

  string_offset = builder.CreateString("lora_params")
  schema_fb.ModelStart(builder)
  schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
  schema_fb.ModelAddDescription(builder, string_offset)
  schema_fb.ModelAddBuffers(builder, buffers_offset)
  schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
  model_offset = schema_fb.ModelEnd(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  flatbuffer_model = builder.Output()

  return flatbuffer_model
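
# A minimal round-trip sketch tying the pieces together, assuming example
# values only: serialize zero-initialized LoRA weights to the TFLite
# FlatBuffers layout via `to_tflite`, then recover them with
# `from_flatbuffers` and compare with the tolerance-aware equality.
def _example_flatbuffers_roundtrip(config: model_config.ModelConfig) -> None:
  lora = LoRA.zeros(rank=8, config=config)
  flatbuffer_model = lora.to_tflite()
  restored = LoRA.from_flatbuffers(flatbuffer_model)
  assert restored == lora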