ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (213) hide show
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
@@ -0,0 +1,557 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """LoRA weights for generative models.
17
+
18
+ The current implementation supports attention-only LoRA. Additionally, we expect
19
+ LoRA weights for all projections within the attention module (i.e., Q, K, V, O).
20
+ """
21
+
22
+ import dataclasses
23
+ from typing import Any, Callable, List, Optional, Tuple
24
+
25
+ from ai_edge_torch.generative.layers import model_config
26
+ import flatbuffers
27
+ import numpy as np
28
+ import safetensors
29
+ import torch
30
+ import torch.utils._pytree as pytree
31
+
32
+ from tensorflow.lite.python import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
33
+
34
# Schema version stored in the `version` field of the emitted TFLite Model.
_TFLITE_SCHEMA_VERSION = 3
# File identifier passed to `builder.Finish` when serializing the model.
_TFLITE_FILE_IDENTIFIER = b"TFL3"
36
+
37
+
38
@dataclasses.dataclass
class LoRAWeight:
  """LoRA weight pair for a single projection.

  The weights are stored pre-transposed, so the low-rank update for an input
  `x` is `(x @ a_prime) @ b_prime`.
  """

  a_prime: torch.Tensor
  b_prime: torch.Tensor

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Returns True if both matrices match in shape, dtype and value.

    Args:
      other: Object to compare against; non-`LoRAWeight` objects are unequal.
      rtol: Relative tolerance forwarded to `torch.allclose`.
      atol: Absolute tolerance forwarded to `torch.allclose`.

    Returns:
      Whether `a_prime` and `b_prime` are both close to `other`'s.
    """
    if not isinstance(other, LoRAWeight):
      return False
    pairs = (
        (self.a_prime, other.a_prime),
        (self.b_prime, other.b_prime),
    )
    # `torch.allclose` raises on mismatched dtypes; an equality check should
    # report inequality instead of erroring.
    if any(
        mine.shape != theirs.shape or mine.dtype != theirs.dtype
        for mine, theirs in pairs
    ):
      return False
    return all(
        torch.allclose(mine, theirs, rtol=rtol, atol=atol)
        for mine, theirs in pairs
    )
55
+
56
+
57
@dataclasses.dataclass
class AttentionLoRA:
  """Per-projection LoRA weights of one attention module.

  Holds one `LoRAWeight` for each of the query, key, value and output
  projections.
  """

  query: LoRAWeight
  key: LoRAWeight
  value: LoRAWeight
  output: LoRAWeight

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Compares the four projections with the given tolerances.

    Non-`AttentionLoRA` objects compare unequal. Projections are checked in
    query, key, value, output order and the first mismatch returns False.
    """
    if not isinstance(other, AttentionLoRA):
      return False
    for field_name in ("query", "key", "value", "output"):
      mine = getattr(self, field_name)
      theirs = getattr(other, field_name)
      if not mine.__eq__(theirs, rtol=rtol, atol=atol):
        return False
    return True
75
+
76
+
77
@dataclasses.dataclass
class LoRAEntry:
  """LoRA weights attached to a single layer (attention projections only)."""

  attention: AttentionLoRA

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Delegates to the attention weights; non-`LoRAEntry` objects differ."""
    return isinstance(other, LoRAEntry) and self.attention.__eq__(
        other.attention, rtol=rtol, atol=atol
    )
87
+
88
+
89
@dataclasses.dataclass
class LoRATensorNames:
  """Checkpoint tensor-name templates for the LoRA weights.

  Every field is a `str.format` template that `LoRA.from_safetensors` fills
  in with the layer index via `.format(i)`; a template without a placeholder
  therefore resolves to the same name for every layer.
  """

  # Query projection: A and B matrices.
  attn_query_w_a: str
  attn_query_w_b: str
  # Key projection: A and B matrices.
  attn_key_w_a: str
  attn_key_w_b: str
  # Value projection: A and B matrices.
  attn_value_w_a: str
  attn_value_w_b: str
  # Output projection: A and B matrices.
  attn_output_w_a: str
  attn_output_w_b: str
104
+
105
+
106
@dataclasses.dataclass
class LoRA:
  """LoRA weights for all modules.

  Attributes:
    adapters: One `LoRAEntry` per layer, ordered by layer index. The
      constructors in this class always store a tuple here.
  """

  adapters: Tuple[LoRAEntry, ...]

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Compares adapters pairwise with the given tolerances.

    Returns False if `other` is not a `LoRA` or holds a different number of
    adapters.
    """
    if not isinstance(other, LoRA):
      return False
    if len(self.adapters) != len(other.adapters):
      return False
    return all(
        adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
        for adapter, other_adapter in zip(self.adapters, other.adapters)
    )

  def get_rank(self) -> int:
    """Returns the rank of the LoRA weights.

    Read from dim 1 of the first layer's query `a_prime`, which
    `_from_tensor_generator` lays out as `(embedding_dim, rank)`. Raises
    IndexError when there are no adapters.
    """
    return self.adapters[0].attention.query.a_prime.shape[1]

  @classmethod
  def from_safetensors(
      cls,
      path: str,
      scale: float,
      config: model_config.ModelConfig,
      lora_tensor_names: LoRATensorNames,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a Hugging Face safetensors checkpoint.

    Args:
      path: Path to the safetensors file.
      scale: Scale factor for the LoRA weights (applied only to the `a_prime`
        matrix of each projection). The scaling factor depends on the training
        configuration. The common values are either `lora_alpha / rank` or
        `lora_alpha / sqrt(rank)`.
      config: Model configuration; weights for `config.num_layers` layers are
        loaded.
      lora_tensor_names: Tensor-name templates, each formatted with the layer
        index.
      dtype: Data type of the LoRA weights. Currently only float32 is
        supported.

    Returns:
      LoRA weights for all modules.
    """

    def load_weight(f, name_a: str, name_b: str, layer: int) -> LoRAWeight:
      # Matrices are transposed (`.T`) so `apply_lora` can right-multiply the
      # activations by them; `scale` is folded into the A matrix only.
      a_prime = f.get_tensor(name_a.format(layer)).to(dtype).T * scale
      b_prime = f.get_tensor(name_b.format(layer)).to(dtype).T
      return LoRAWeight(a_prime=a_prime, b_prime=b_prime)

    names = lora_tensor_names
    adapters = []
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:
      for i in range(config.num_layers):
        attention_lora = AttentionLoRA(
            query=load_weight(
                f, names.attn_query_w_a, names.attn_query_w_b, i
            ),
            key=load_weight(f, names.attn_key_w_a, names.attn_key_w_b, i),
            value=load_weight(
                f, names.attn_value_w_a, names.attn_value_w_b, i
            ),
            output=load_weight(
                f, names.attn_output_w_a, names.attn_output_w_b, i
            ),
        )
        adapters.append(LoRAEntry(attention=attention_lora))
    # Store a tuple, matching the `adapters` annotation and `_unflatten_lora`.
    return cls(adapters=tuple(adapters))

  @classmethod
  def from_flatbuffers(
      cls,
      flatbuffer_model: bytearray,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a FlatBuffer produced by `to_tflite`.

    Every tensor of the first subgraph must be named "lora_<flat name>" and
    backed by a float32 buffer.

    Args:
      flatbuffer_model: Serialized TFLite-schema FlatBuffer.
      dtype: Data type of the returned LoRA weights.

    Returns:
      LoRA weights for all modules.
    """
    model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
    model = schema_fb.ModelT.InitFromObj(model)

    flat_names = []
    tensors = []
    for tensor in model.subgraphs[0].tensors:
      name = tensor.name.decode("utf-8")
      assert name.startswith("lora_")
      # Strip only the leading prefix; splitting on "lora_" would also cut at
      # any later occurrence inside the name.
      flat_names.append(name.removeprefix("lora_"))
      buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
      # `np.frombuffer` over immutable bytes gives a read-only array, which
      # `torch.from_numpy` warns about and would alias; copy it first.
      arr = (
          np.frombuffer(buffer_bytes, dtype=np.float32)
          .reshape(tensor.shape)
          .copy()
      )
      tensors.append(torch.from_numpy(arr).to(dtype))

    return _unflatten_lora(tensors, (flat_names, []))

  @classmethod
  def zeros(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with zeros.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights. Currently only float32 is
        supported.

    Returns:
      LoRA weights with zeros.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def random(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with random values.

    Values are integers drawn uniformly from [0, 128) by `torch.randint` and
    stored with the requested `dtype`.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights with random values.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.randint(
            low=0, high=128, size=shape, dtype=dtype
        ),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def _from_tensor_generator(
      cls,
      tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights whose tensors come from `tensor_generator`.

    The generator is called with `(shape, dtype)` in a fixed order per layer
    (query A/B, key A/B, value A/B, output A/B), which keeps random
    initialization reproducible.
    """
    adapters = []
    for i in range(config.num_layers):
      attn_config = config.block_config(i).attn_config
      q_per_kv = attn_config.num_heads // attn_config.num_query_groups
      q_out_dim = q_per_kv * attn_config.head_dim
      # NOTE(review): K/V output dims use a single head_dim; confirm this is
      # intended for models with num_query_groups > 1.
      k_out_dim = v_out_dim = attn_config.head_dim
      o_in_dim = attn_config.num_heads * attn_config.head_dim
      attention_lora = AttentionLoRA(
          query=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, q_out_dim), dtype),
          ),
          key=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, k_out_dim), dtype),
          ),
          value=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, v_out_dim), dtype),
          ),
          output=LoRAWeight(
              a_prime=tensor_generator((o_in_dim, rank), dtype),
              b_prime=tensor_generator((rank, config.embedding_dim), dtype),
          ),
      )
      adapters.append(LoRAEntry(attention=attention_lora))
    # Store a tuple, matching the `adapters` annotation and `_unflatten_lora`.
    return cls(adapters=tuple(adapters))

  def to_tflite(self) -> bytearray:
    """Converts LoRA to a TFLite-schema FlatBuffer."""
    return _lora_to_flatbuffers(self)
329
+
330
+
331
def apply_lora(
    x: torch.Tensor,
    lora_weight: LoRAWeight,
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Computes the low-rank LoRA update `(x @ a_prime) @ b_prime`.

  Args:
    x: Input activations.
    lora_weight: Pre-transposed LoRA matrices to apply.
    shape: Optional target shape. When given, the product is reshaped to it;
      when None, the product is returned exactly as produced by the two
      matrix multiplications.

  Returns:
    The LoRA update tensor.
  """
  projected = x @ lora_weight.a_prime
  update = projected @ lora_weight.b_prime
  return update if shape is None else update.reshape(shape)
353
+
354
+
355
def _flatten_attention_lora(
    lora: AttentionLoRA, block_index: int
) -> Tuple[List[torch.Tensor], List[str]]:
  """Flattens one attention module's LoRA weights into tensors and names.

  Produces eight entries in the order q, k, v, o and, within each projection,
  A before B. Each is named "atten_{q|k|v|o}_{a|b}_prime_weight_{block_index}".
  """
  tensors: List[torch.Tensor] = []
  names: List[str] = []
  projections = (
      ("q", lora.query),
      ("k", lora.key),
      ("v", lora.value),
      ("o", lora.output),
  )
  for tag, projection in projections:
    for part, tensor in (("a", projection.a_prime), ("b", projection.b_prime)):
      tensors.append(tensor)
      names.append(f"atten_{tag}_{part}_prime_weight_{block_index}")
  return tensors, names
378
+
379
+
380
def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
  """Flattens all LoRA weights for pytree registration.

  Returns:
    `(tensors, [names, []])`: the weight tensors of every adapter in adapter
    order, and a context whose first element holds the matching flat names
    and whose second element is always an empty list.
  """
  tensors: List[torch.Tensor] = []
  names: List[str] = []
  for block_index, entry in enumerate(lora.adapters):
    block_tensors, block_names = _flatten_attention_lora(
        lora=entry.attention, block_index=block_index
    )
    tensors += block_tensors
    names += block_names
  return tensors, [names, []]
392
+
393
+
394
def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
  """Flattens LoRA weights, pairing each tensor with a pytree key.

  Returns:
    `(key_children, context)`. `key_children` is a list of
    `(pytree.MappingKey(name), tensor)` pairs. `context` is the same
    `[flat_names, none_names]` context `_flatten_lora` returns, so it can be
    passed to `_unflatten_lora` (which unpacks it into two parts).
  """
  flattened, context = _flatten_lora(lora)
  flat_names, _ = context
  key_children = [
      (pytree.MappingKey(name), tensor)
      for name, tensor in zip(flat_names, flattened)
  ]
  return key_children, context
400
+
401
+
402
def _unflatten_lora(
    values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
) -> LoRA:
  """Rebuilds a `LoRA` object from flat tensors and their names.

  Args:
    values: Weight tensors; `values[i]` is named by `context[0][i]`.
    context: `(flat_names, _)` as produced by `_flatten_lora`. Each name looks
      like "atten_{q|k|v|o}_{a|b}_prime_weight_{block_index}".

  Returns:
    A `LoRA` whose adapters are ordered by ascending block index. Projection
    weights with no matching name are left as None.

  Raises:
    ValueError: If a name's last "_"-separated token is not an integer, if a
      name does not start with "atten_", or if it names no known weight.
  """
  flat_names, _ = context
  # (substring, projection attribute, LoRAWeight field), tried in this order.
  slots = (
      ("q_a_prime", "query", "a_prime"),
      ("q_b_prime", "query", "b_prime"),
      ("k_a_prime", "key", "a_prime"),
      ("k_b_prime", "key", "b_prime"),
      ("v_a_prime", "value", "a_prime"),
      ("v_b_prime", "value", "b_prime"),
      ("o_a_prime", "output", "a_prime"),
      ("o_b_prime", "output", "b_prime"),
  )

  def empty_entry() -> LoRAEntry:
    # Placeholder entry; weights are filled in as names are matched.
    return LoRAEntry(
        attention=AttentionLoRA(
            query=LoRAWeight(a_prime=None, b_prime=None),
            key=LoRAWeight(a_prime=None, b_prime=None),
            value=LoRAWeight(a_prime=None, b_prime=None),
            output=LoRAWeight(a_prime=None, b_prime=None),
        )
    )

  adapters = {}
  # Single pass over the pairs (avoids the quadratic cost of list.pop(0)).
  for name, weight in zip(flat_names, values):
    block_idx = int(name.split("_")[-1])
    if block_idx not in adapters:
      adapters[block_idx] = empty_entry()

    if not name.startswith("atten_"):
      raise ValueError(f"Unsupported name: {name}")
    for substring, projection, field in slots:
      if substring in name:
        setattr(
            getattr(adapters[block_idx].attention, projection), field, weight
        )
        break
    else:
      raise ValueError(f"Unsupported name: {name}")

  return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))
457
+
458
+
459
# Register `LoRA` as a pytree node so `torch.utils._pytree` utilities can
# flatten a LoRA object into its weight tensors and rebuild it from them.
pytree.register_pytree_node(
    LoRA,
    _flatten_lora,
    _unflatten_lora,
    serialized_type_name="",
    flatten_with_keys_fn=_flatten_lora_with_keys,
)
466
+
467
+
468
def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int:
  """Writes a TFLite `Buffer` table holding `data` and returns its offset.

  Args:
    builder: FlatBuffers builder to write into.
    data: float32 array whose raw values become the buffer bytes, or None to
      write an empty buffer.

  Returns:
    Offset of the finished `Buffer` table.
  """
  if data is None:
    schema_fb.BufferStartDataVector(builder, 0)
  else:
    assert data.dtype == np.float32
    # The data vector length is in bytes; FlatBuffers fills vectors
    # back-to-front, so the values are prepended in reverse order.
    schema_fb.BufferStartDataVector(builder, data.nbytes)
    for value in data.ravel()[::-1].tolist():
      builder.PrependFloat32(value)
  data_offset = builder.EndVector()

  schema_fb.BufferStart(builder)
  schema_fb.BufferAddData(builder, data_offset)
  return schema_fb.BufferEnd(builder)
484
+
485
+
486
def _add_tensor(
    builder: flatbuffers.Builder,
    name: str,
    shape: Tuple[int, ...],
    buffer_idx: int,
) -> int:
  """Writes a FLOAT32 TFLite `Tensor` table and returns its offset.

  Args:
    builder: FlatBuffers builder to write into.
    name: Tensor name.
    shape: Tensor dimensions, stored as int32.
    buffer_idx: Index of the model buffer holding the tensor data.

  Returns:
    Offset of the finished `Tensor` table.
  """
  name_offset = builder.CreateString(name)
  dims = list(shape)
  schema_fb.TensorStartShapeVector(builder, len(dims))
  # FlatBuffers vectors are built back-to-front.
  for dim in dims[::-1]:
    builder.PrependInt32(dim)
  shape_offset = builder.EndVector()

  schema_fb.TensorStart(builder)
  schema_fb.TensorAddName(builder, name_offset)
  schema_fb.TensorAddShape(builder, shape_offset)
  schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
  schema_fb.TensorAddBuffer(builder, buffer_idx)
  return schema_fb.TensorEnd(builder)
505
+
506
+
507
def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
  """Serializes LoRA weights into a TFLite-schema FlatBuffer.

  The result has a single subgraph named "lora_params" that holds one FLOAT32
  tensor per LoRA weight. Buffer 0 is an empty placeholder, so tensor `i`
  refers to buffer `i + 1`.

  Args:
    lora: LoRA weights to serialize. Tensors are detached, moved to CPU and
      cast to float32 before being written.

  Returns:
    The serialized model bytes; `LoRA.from_flatbuffers` reads this layout.
  """
  tensors, (names, _) = _flatten_lora(lora)
  # The model export adds the "lora_" prefix to these names automatically, so
  # it is added by hand here to match (`LoRA.from_flatbuffers` requires it).
  names = [f"lora_{name}" for name in names]
  builder = flatbuffers.Builder(4096)

  def end_offset_vector(start_vector_fn, offsets: List[int]) -> int:
    # Writes a vector of table offsets; FlatBuffers fills vectors
    # back-to-front.
    start_vector_fn(builder, len(offsets))
    for offset in reversed(offsets):
      builder.PrependUOffsetTRelative(offset)
    return builder.EndVector()

  # Convention: buffer 0 is empty.
  buffer_offsets = [_add_buffer(builder, None)]
  for tensor in tensors:
    # `.numpy()` only works on CPU tensors, so move off any accelerator first.
    data = tensor.detach().cpu().to(torch.float32).numpy()
    buffer_offsets.append(_add_buffer(builder, data))
  buffers_offset = end_offset_vector(
      schema_fb.ModelStartBuffersVector, buffer_offsets
  )

  # Buffer indices start at 1 because buffer 0 is the empty placeholder.
  tensor_offsets = [
      _add_tensor(builder, name, tensor.shape, buffer_idx)
      for buffer_idx, (name, tensor) in enumerate(zip(names, tensors), start=1)
  ]
  tensors_offset = end_offset_vector(
      schema_fb.SubGraphStartTensorsVector, tensor_offsets
  )

  string_offset = builder.CreateString("lora_params")
  schema_fb.SubGraphStart(builder)
  schema_fb.SubGraphAddName(builder, string_offset)
  schema_fb.SubGraphAddTensors(builder, tensors_offset)
  subgraph_offset = schema_fb.SubGraphEnd(builder)

  subgraphs_offset = end_offset_vector(
      schema_fb.ModelStartSubgraphsVector, [subgraph_offset]
  )

  string_offset = builder.CreateString("lora_params")
  schema_fb.ModelStart(builder)
  schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
  schema_fb.ModelAddDescription(builder, string_offset)
  schema_fb.ModelAddBuffers(builder, buffers_offset)
  schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
  model_offset = schema_fb.ModelEnd(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return builder.Output()