ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +41 -8
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +7 -2
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +24 -22
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/lora.py (new file)
@@ -0,0 +1,557 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""LoRA weights for generative models.

The current implementation supports attention-only LoRA. Additionally, we expect
LoRA weights for all projections within the attention module (i.e., Q, K, V, O).
"""

import dataclasses
from typing import Any, Callable, List, Optional, Tuple

from ai_edge_torch.generative.layers import model_config
import flatbuffers
import numpy as np
import safetensors
import torch
import torch.utils._pytree as pytree

from tensorflow.lite.python import schema_py_generated as schema_fb  # pylint: disable=g-direct-tensorflow-import

_TFLITE_SCHEMA_VERSION = 3
_TFLITE_FILE_IDENTIFIER = b"TFL3"


@dataclasses.dataclass
class LoRAWeight:
  """LoRA weight per projection. The weights are pre-transposed."""

  a_prime: torch.Tensor
  b_prime: torch.Tensor

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, LoRAWeight):
      return False
    if self.a_prime.shape != other.a_prime.shape:
      return False
    if self.b_prime.shape != other.b_prime.shape:
      return False
    return torch.allclose(
        self.a_prime, other.a_prime, rtol=rtol, atol=atol
    ) and torch.allclose(self.b_prime, other.b_prime, rtol=rtol, atol=atol)


@dataclasses.dataclass
class AttentionLoRA:
  """LoRA weights for attention module."""

  query: LoRAWeight
  key: LoRAWeight
  value: LoRAWeight
  output: LoRAWeight

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, AttentionLoRA):
      return False
    return (
        self.query.__eq__(other.query, rtol=rtol, atol=atol)
        and self.key.__eq__(other.key, rtol=rtol, atol=atol)
        and self.value.__eq__(other.value, rtol=rtol, atol=atol)
        and self.output.__eq__(other.output, rtol=rtol, atol=atol)
    )


@dataclasses.dataclass
class LoRAEntry:
  """LoRA weights for a single layer."""

  attention: AttentionLoRA

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, LoRAEntry):
      return False
    return self.attention.__eq__(other.attention, rtol=rtol, atol=atol)


@dataclasses.dataclass
class LoRATensorNames:
  """Tensor names for LoRA weights."""

  attn_query_w_a: str
  attn_query_w_b: str

  attn_key_w_a: str
  attn_key_w_b: str

  attn_value_w_a: str
  attn_value_w_b: str

  attn_output_w_a: str
  attn_output_w_b: str


@dataclasses.dataclass
class LoRA:
  """LoRA weights for all modules."""

  adapters: Tuple[LoRAEntry, ...]

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    if not isinstance(other, LoRA):
      return False
    if len(self.adapters) != len(other.adapters):
      return False
    return all(
        adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
        for adapter, other_adapter in zip(self.adapters, other.adapters)
    )

  def get_rank(self) -> int:
    """Returns the rank of the LoRA weights."""
    return self.adapters[0].attention.query.a_prime.shape[1]

  @classmethod
  def from_safetensors(
      cls,
      path: str,
      scale: float,
      config: model_config.ModelConfig,
      lora_tensor_names: LoRATensorNames,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a Hugging Face model.

    Args:
      path: Path to the model.
      scale: Scale factor for the LoRA weights (applied only to one of the
        projections). The scaling factor depends on the training configuration.
        The common values are either `lora_alpha / rank` or `lora_alpha /
        sqrt(rank)`.
      config: Model configuration.
      lora_tensor_names: Tensor names for the LoRA weights.
      dtype: Data type of the LoRA weights. Currently only float32 is supported.

    Returns:
      LoRA weights for all modules.
    """
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:
      adapters = []
      for i in range(config.num_layers):
        attention_lora = AttentionLoRA(
            query=LoRAWeight(
                a_prime=f.get_tensor(lora_tensor_names.attn_query_w_a.format(i))
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(lora_tensor_names.attn_query_w_b.format(i))
                .to(dtype)
                .T,
            ),
            key=LoRAWeight(
                a_prime=f.get_tensor(lora_tensor_names.attn_key_w_a.format(i))
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(lora_tensor_names.attn_key_w_b.format(i))
                .to(dtype)
                .T,
            ),
            value=LoRAWeight(
                a_prime=f.get_tensor(lora_tensor_names.attn_value_w_a.format(i))
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(lora_tensor_names.attn_value_w_b.format(i))
                .to(dtype)
                .T,
            ),
            output=LoRAWeight(
                a_prime=f.get_tensor(
                    lora_tensor_names.attn_output_w_a.format(i)
                )
                .to(dtype)
                .T
                * scale,
                b_prime=f.get_tensor(
                    lora_tensor_names.attn_output_w_b.format(i)
                )
                .to(dtype)
                .T,
            ),
        )
        adapters.append(LoRAEntry(attention=attention_lora))
    return cls(adapters=adapters)

  @classmethod
  def from_flatbuffers(
      cls,
      flatbuffer_model: bytearray,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from FlatBuffers.

    Args:
      flatbuffer_model: FlatBuffers model.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights for all modules.
    """
    model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
    model = schema_fb.ModelT.InitFromObj(model)

    flat_names = []
    tensors = []
    for tensor in model.subgraphs[0].tensors:
      name = tensor.name.decode("utf-8")
      assert name.startswith("lora_")
      flat_names.append(name.split("lora_")[-1])
      buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
      arr = np.frombuffer(buffer_bytes, dtype=np.float32).reshape(tensor.shape)
      torch_tensor = torch.from_numpy(arr).to(dtype)
      tensors.append(torch_tensor)

    return _unflatten_lora(tensors, (flat_names, []))

  @classmethod
  def zeros(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with zeros.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights. Currently only float32 is supported.

    Returns:
      LoRA weights with zeros.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def random(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with random values.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights with random values.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.randint(
            low=0, high=128, size=shape, dtype=dtype
        ),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def _from_tensor_generator(
      cls,
      tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a tensor generator."""
    adapters = []

    for i in range(config.num_layers):
      block_config = config.block_config(i)
      q_per_kv = (
          block_config.attn_config.num_heads
          // block_config.attn_config.num_query_groups
      )
      q_out_dim = q_per_kv * block_config.attn_config.head_dim
      k_out_dim = v_out_dim = block_config.attn_config.head_dim
      attention_lora = AttentionLoRA(
          query=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, q_out_dim), dtype),
          ),
          key=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, k_out_dim), dtype),
          ),
          value=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, v_out_dim), dtype),
          ),
          output=LoRAWeight(
              a_prime=tensor_generator(
                  (
                      block_config.attn_config.num_heads
                      * block_config.attn_config.head_dim,
                      rank,
                  ),
                  dtype,
              ),
              b_prime=tensor_generator((rank, config.embedding_dim), dtype),
          ),
      )
      adapters.append(LoRAEntry(attention=attention_lora))
    return cls(adapters=adapters)

  def to_tflite(self) -> bytearray:
    """Converts LoRA to FlatBuffers."""
    return _lora_to_flatbuffers(self)


def apply_lora(
    x: torch.Tensor,
    lora_weight: LoRAWeight,
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Applies LoRA weights to a tensor.

  Args:
    x: Input tensor.
    lora_weight: LoRA weight.
    shape: Output shape. If None, the output shape is the same as the input
      shape.

  Returns:
    Output tensor.
  """
  output = torch.matmul(
      torch.matmul(x, lora_weight.a_prime), lora_weight.b_prime
  )
  if shape is not None:
    output = output.reshape(shape)
  return output


def _flatten_attention_lora(
    lora: AttentionLoRA, block_index: int
) -> Tuple[List[torch.Tensor], List[str]]:
  """Flattens LoRA weights for attention module."""
  flattened = []
  flat_names = []
  flattened.append(lora.query.a_prime)
  flat_names.append(f"atten_q_a_prime_weight_{block_index}")
  flattened.append(lora.query.b_prime)
  flat_names.append(f"atten_q_b_prime_weight_{block_index}")
  flattened.append(lora.key.a_prime)
  flat_names.append(f"atten_k_a_prime_weight_{block_index}")
  flattened.append(lora.key.b_prime)
  flat_names.append(f"atten_k_b_prime_weight_{block_index}")
  flattened.append(lora.value.a_prime)
  flat_names.append(f"atten_v_a_prime_weight_{block_index}")
  flattened.append(lora.value.b_prime)
  flat_names.append(f"atten_v_b_prime_weight_{block_index}")
  flattened.append(lora.output.a_prime)
  flat_names.append(f"atten_o_a_prime_weight_{block_index}")
  flattened.append(lora.output.b_prime)
  flat_names.append(f"atten_o_b_prime_weight_{block_index}")
  return flattened, flat_names


def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
  """Flattens LoRA weights."""
  flattened = []
  flat_names = []
  none_names = []
  for i, entry in enumerate(lora.adapters):
    attn_flattened, attn_flat_names = _flatten_attention_lora(
        lora=entry.attention, block_index=i
    )
    flattened.extend(attn_flattened)
    flat_names.extend(attn_flat_names)
  return flattened, [flat_names, none_names]


def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
  """Flattens LoRA weights with keys."""
  flattened, (flat_names, _) = _flatten_lora(lora)
  return [
      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
  ], flat_names


def _unflatten_lora(
    values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
) -> LoRA:
  """Unflattens LoRA object."""
  flat_names, _ = context
  names_weights = list(zip(flat_names, values))
  adapters = {}
  while names_weights:
    name, weight = names_weights.pop(0)
    block_idx = int(name.split("_")[-1])
    if block_idx not in adapters:
      adapters[block_idx] = LoRAEntry(
          attention=AttentionLoRA(
              query=LoRAWeight(
                  a_prime=None,
                  b_prime=None,
              ),
              key=LoRAWeight(
                  a_prime=None,
                  b_prime=None,
              ),
              value=LoRAWeight(
                  a_prime=None,
                  b_prime=None,
              ),
              output=LoRAWeight(
                  a_prime=None,
                  b_prime=None,
              ),
          )
      )

    if name.startswith("atten_"):
      if "q_a_prime" in name:
        adapters[block_idx].attention.query.a_prime = weight
      elif "q_b_prime" in name:
        adapters[block_idx].attention.query.b_prime = weight
      elif "k_a_prime" in name:
        adapters[block_idx].attention.key.a_prime = weight
      elif "k_b_prime" in name:
        adapters[block_idx].attention.key.b_prime = weight
      elif "v_a_prime" in name:
        adapters[block_idx].attention.value.a_prime = weight
      elif "v_b_prime" in name:
        adapters[block_idx].attention.value.b_prime = weight
      elif "o_a_prime" in name:
        adapters[block_idx].attention.output.a_prime = weight
      elif "o_b_prime" in name:
        adapters[block_idx].attention.output.b_prime = weight
      else:
        raise ValueError(f"Unsupported name: {name}")
    else:
      raise ValueError(f"Unsupported name: {name}")

  return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))


pytree.register_pytree_node(
    LoRA,
    _flatten_lora,
    _unflatten_lora,
    flatten_with_keys_fn=_flatten_lora_with_keys,
    serialized_type_name="",
)


def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int:
  """Adds a buffer to the FlatBuffers."""
  if data is not None:
    assert data.dtype == np.float32
    schema_fb.BufferStartDataVector(builder, data.size * data.itemsize)
    for value in reversed(data.flatten().tolist()):
      builder.PrependFloat32(value)
    data_offset = builder.EndVector()
  else:
    schema_fb.BufferStartDataVector(builder, 0)
    data_offset = builder.EndVector()

  schema_fb.BufferStart(builder)
  schema_fb.BufferAddData(builder, data_offset)
  buffer_offset = schema_fb.BufferEnd(builder)
  return buffer_offset


def _add_tensor(
    builder: flatbuffers.Builder,
    name: str,
    shape: Tuple[int, ...],
    buffer_idx: int,
) -> int:
  """Adds a tensor to the FlatBuffers."""
  name_offset = builder.CreateString(name)
  schema_fb.TensorStartShapeVector(builder, len(shape))
  for dim in reversed(shape):
    builder.PrependInt32(dim)
  shape_offset = builder.EndVector()
  schema_fb.TensorStart(builder)
  schema_fb.TensorAddName(builder, name_offset)
  schema_fb.TensorAddShape(builder, shape_offset)
  schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
  schema_fb.TensorAddBuffer(builder, buffer_idx)
  tensor_offset = schema_fb.TensorEnd(builder)
  return tensor_offset


def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
  """Converts LoRA to FlatBuffers."""
  tensors, (names, _) = _flatten_lora(lora)
  # Need to manually add the "lora_" prefix to the names here. The export will
  # add the prefix automatically.
  names = [f"lora_{name}" for name in names]
  builder = flatbuffers.Builder(4096)

  # Convention to add an empty buffer in the beginning.
  buffer_offsets = [_add_buffer(builder, None)]
  for tensor in tensors:
    buffer_offsets.append(
        _add_buffer(builder, tensor.detach().type(torch.float32).numpy())
    )

  schema_fb.ModelStartBuffersVector(builder, len(buffer_offsets))
  for buffer_offset in reversed(buffer_offsets):
    builder.PrependUOffsetTRelative(buffer_offset)
  buffers_offset = builder.EndVector()

  tensor_offsets = []
  for i, (name, tensor) in enumerate(zip(names, tensors)):
    # Note that the zeroth buffer is empty and reserved for the convention.
    tensor_offsets.append(_add_tensor(builder, name, tensor.shape, i + 1))

  schema_fb.SubGraphStartTensorsVector(builder, len(tensor_offsets))
  for tensor_offset in reversed(tensor_offsets):
    builder.PrependUOffsetTRelative(tensor_offset)
  tensors_offset = builder.EndVector()

  string_offset = builder.CreateString("lora_params")
  schema_fb.SubGraphStart(builder)
  schema_fb.SubGraphAddName(builder, string_offset)
  schema_fb.SubGraphAddTensors(builder, tensors_offset)
  subgraph_offset = schema_fb.SubGraphEnd(builder)

  schema_fb.ModelStartSubgraphsVector(builder, 1)
  builder.PrependUOffsetTRelative(subgraph_offset)
  subgraphs_offset = builder.EndVector()

  string_offset = builder.CreateString("lora_params")
  schema_fb.ModelStart(builder)
  schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
  schema_fb.ModelAddDescription(builder, string_offset)
  schema_fb.ModelAddBuffers(builder, buffers_offset)
  schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
  model_offset = schema_fb.ModelEnd(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  flatbuffer_model = builder.Output()

  return flatbuffer_model
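For orientation, here is a minimal usage sketch of the new module. It is not part of the package diff; it uses only the APIs shown above, and the tiny single-block config mirrors the test helper in test_lora.py below.

# Sketch only, not shipped code: exercises LoRA.zeros, apply_lora, and the
# flatbuffer round trip from ai_edge_torch/generative/layers/lora.py.
import torch
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.layers import lora as lora_utils

# A tiny single-block model config, mirroring _get_test_config in test_lora.py.
attn_config = cfg.AttentionConfig(
    num_heads=1, head_dim=8, num_query_groups=1
)
block_config = cfg.TransformerBlockConfig(
    attn_config=attn_config, ff_config=None
)
config = cfg.ModelConfig(
    kv_cache_max_len=16,
    embedding_dim=8,
    block_configs=block_config,
    num_layers=1,
    max_seq_len=None,
    vocab_size=None,
)

# Zero-valued rank-16 adapters; apply_lora computes x @ A' @ B' for one
# projection of one block.
lora = lora_utils.LoRA.zeros(rank=16, config=config)
x = torch.zeros((1, 1, config.embedding_dim))
delta = lora_utils.apply_lora(x, lora.adapters[0].attention.query)

# Round-trip the adapters through the standalone TFLite flatbuffer form.
flatbuffer_model = lora.to_tflite()
recovered = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
assert recovered == lora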
ai_edge_torch/generative/test/test_lora.py (new file)
@@ -0,0 +1,147 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A suite of tests to validate LoRA utilities."""

from ai_edge_torch.generative.layers import lora as lora_utils
import ai_edge_torch.generative.layers.model_config as cfg
import torch
from absl.testing import absltest as googletest
from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import


class TestLora(googletest.TestCase):
  """Tests for LoRA utilities."""

  def test_safetensors_builder(self):
    """Converts a safetensors file to a LoRA module."""

    tensor_names = lora_utils.LoRATensorNames(
        attn_query_w_a=(
            "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
        ),
        attn_query_w_b=(
            "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
        ),
        attn_key_w_a=(
            "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
        ),
        attn_key_w_b=(
            "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
        ),
        attn_value_w_a=(
            "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
        ),
        attn_value_w_b=(
            "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
        ),
        attn_output_w_a=(
            "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
        ),
        attn_output_w_b=(
            "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
        ),
    )

    safetensors_file = resource_loader.get_path_to_datafile(
        "fixtures/test_lora_rank16.safetensors"
    )
    config = self._get_test_config(
        num_layers=1,
        head_dim=8,
        num_query_groups=1,
        kv_cache_max_len=16,
    )
    lora = lora_utils.LoRA.from_safetensors(
        safetensors_file,
        scale=1.0,
        lora_tensor_names=tensor_names,
        config=config,
    )
    self.assertEqual(lora.get_rank(), 16)

  def test_torch_export(self):
    """Tests the export of the LoRA module."""

    class TestModel(torch.nn.Module):

      def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
        x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
        return x

    n = 1
    head_dim = 2
    num_query_groups = 1
    key_length = 4
    config = self._get_test_config(
        num_layers=n,
        head_dim=head_dim,
        num_query_groups=num_query_groups,
        kv_cache_max_len=key_length,
    )
    inputs = torch.zeros((n, 1, head_dim))
    lora = lora_utils.LoRA.zeros(rank=16, config=config)
    model = TestModel()
    exported_program = torch.export.export(model, (inputs, lora))
    input_specs = exported_program.graph_signature.input_specs
    # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
    # 2 for output lora.
    self.assertLen(input_specs, 9)
    self.assertEqual(input_specs[0].arg.name, "x")
    self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
    self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
    self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
    self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
    self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
    self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
    self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
    self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")

  def test_lora_tflite_serialization(self):
    """Tests the serialization of the LoRA module."""
    config = self._get_test_config(
        num_layers=2,
        head_dim=8,
        num_query_groups=1,
        kv_cache_max_len=16,
    )
    lora = lora_utils.LoRA.random(rank=16, config=config)
    flatbuffer_model = lora.to_tflite()
    recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
    self.assertEqual(lora, recovered_lora)

  def _get_test_config(
      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
  ):
    """Returns a test model config."""
    attn_config = cfg.AttentionConfig(
        num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
    )
    block_config = cfg.TransformerBlockConfig(
        attn_config=attn_config, ff_config=None
    )
    config = cfg.ModelConfig(
        kv_cache_max_len=kv_cache_max_len,
        embedding_dim=head_dim,
        block_configs=block_config,
        num_layers=num_layers,
        max_seq_len=None,
        vocab_size=None,
    )
    return config


if __name__ == "__main__":
  googletest.main()