ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

Files changed (24)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  4. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  5. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  6. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  7. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  9. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  10. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  11. ai_edge_torch/generative/layers/attention.py +41 -8
  12. ai_edge_torch/generative/layers/lora.py +557 -0
  13. ai_edge_torch/generative/test/test_lora.py +147 -0
  14. ai_edge_torch/generative/utilities/converter.py +100 -47
  15. ai_edge_torch/generative/utilities/model_builder.py +7 -2
  16. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  17. ai_edge_torch/odml_torch/export.py +6 -2
  18. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +24 -22
  22. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
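The headline change in this build is LoRA support for the generative API: a new ai_edge_torch/generative/layers/lora.py module with tests, plus LoRA plumbing through the attention layer, the converter utilities, and the per-model conversion scripts listed above. As a quick orientation before the diffs, the sketch below illustrates the idea the new module implements: adding a low-rank residual (x @ A') @ B' on top of a frozen base projection. This is an illustration only, not the actual attention.py change (which is not shown in this diff); the sizes and the base_q_proj layer are invented for the example, while LoRAWeight and apply_lora come from the new lora.py shown below.

import torch
from ai_edge_torch.generative.layers import lora as lora_utils

# Hypothetical sizes, chosen only for illustration.
embedding_dim, q_out_dim, rank = 8, 8, 4
x = torch.randn(1, 4, embedding_dim)  # (batch, sequence, embedding_dim)
base_q_proj = torch.nn.Linear(embedding_dim, q_out_dim, bias=False)

# a_prime/b_prime are pre-transposed, matching LoRAWeight's convention.
lora_weight = lora_utils.LoRAWeight(
    a_prime=torch.randn(embedding_dim, rank),
    b_prime=torch.randn(rank, q_out_dim),
)

# LoRA adds a low-rank correction on top of the frozen base projection:
#   y = W(x) + (x @ A') @ B'
y = base_q_proj(x) + lora_utils.apply_lora(x, lora_weight)
print(y.shape)  # torch.Size([1, 4, 8])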
@@ -0,0 +1,557 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """LoRA weights for generative models.
+
+ The current implementation supports attention-only LoRA. Additionally, we expect
+ LoRA weights for all projections within the attention module (i.e., Q, K, V, O).
+ """
+
+ import dataclasses
+ from typing import Any, Callable, List, Optional, Tuple
+
+ from ai_edge_torch.generative.layers import model_config
+ import flatbuffers
+ import numpy as np
+ import safetensors
+ import torch
+ import torch.utils._pytree as pytree
+
+ from tensorflow.lite.python import schema_py_generated as schema_fb  # pylint: disable=g-direct-tensorflow-import
+
+ _TFLITE_SCHEMA_VERSION = 3
+ _TFLITE_FILE_IDENTIFIER = b"TFL3"
+
+
+ @dataclasses.dataclass
+ class LoRAWeight:
+   """LoRA weight per projection. The weights are pre-transposed."""
+
+   a_prime: torch.Tensor
+   b_prime: torch.Tensor
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRAWeight):
+       return False
+     if self.a_prime.shape != other.a_prime.shape:
+       return False
+     if self.b_prime.shape != other.b_prime.shape:
+       return False
+     return torch.allclose(
+         self.a_prime, other.a_prime, rtol=rtol, atol=atol
+     ) and torch.allclose(self.b_prime, other.b_prime, rtol=rtol, atol=atol)
+
+
+ @dataclasses.dataclass
+ class AttentionLoRA:
+   """LoRA weights for attention module."""
+
+   query: LoRAWeight
+   key: LoRAWeight
+   value: LoRAWeight
+   output: LoRAWeight
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, AttentionLoRA):
+       return False
+     return (
+         self.query.__eq__(other.query, rtol=rtol, atol=atol)
+         and self.key.__eq__(other.key, rtol=rtol, atol=atol)
+         and self.value.__eq__(other.value, rtol=rtol, atol=atol)
+         and self.output.__eq__(other.output, rtol=rtol, atol=atol)
+     )
+
+
+ @dataclasses.dataclass
+ class LoRAEntry:
+   """LoRA weights for a single layer."""
+
+   attention: AttentionLoRA
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRAEntry):
+       return False
+     return self.attention.__eq__(other.attention, rtol=rtol, atol=atol)
+
+
+ @dataclasses.dataclass
+ class LoRATensorNames:
+   """Tensor names for LoRA weights."""
+
+   attn_query_w_a: str
+   attn_query_w_b: str
+
+   attn_key_w_a: str
+   attn_key_w_b: str
+
+   attn_value_w_a: str
+   attn_value_w_b: str
+
+   attn_output_w_a: str
+   attn_output_w_b: str
+
+
+ @dataclasses.dataclass
+ class LoRA:
+   """LoRA weights for all modules."""
+
+   adapters: Tuple[LoRAEntry, ...]
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRA):
+       return False
+     if len(self.adapters) != len(other.adapters):
+       return False
+     return all(
+         adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
+         for adapter, other_adapter in zip(self.adapters, other.adapters)
+     )
+
+   def get_rank(self) -> int:
+     """Returns the rank of the LoRA weights."""
+     return self.adapters[0].attention.query.a_prime.shape[1]
+
+   @classmethod
+   def from_safetensors(
+       cls,
+       path: str,
+       scale: float,
+       config: model_config.ModelConfig,
+       lora_tensor_names: LoRATensorNames,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from a Hugging Face model.
+
+     Args:
+       path: Path to the model.
+       scale: Scale factor for the LoRA weights (applied only to one of the
+         projections). The scaling factor depends on the training configuration.
+         The common values are either `lora_alpha / rank` or `lora_alpha /
+         sqrt(rank)`.
+       config: Model configuration.
+       lora_tensor_names: Tensor names for the LoRA weights.
+       dtype: Data type of the LoRA weights. Currently only float32 is supported.
+
+     Returns:
+       LoRA weights for all modules.
+     """
+     with safetensors.safe_open(path, framework="pt", device="cpu") as f:
+       adapters = []
+       for i in range(config.num_layers):
+         attention_lora = AttentionLoRA(
+             query=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_query_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_query_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             key=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_key_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_key_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             value=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_value_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_value_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             output=LoRAWeight(
+                 a_prime=f.get_tensor(
+                     lora_tensor_names.attn_output_w_a.format(i)
+                 )
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(
+                     lora_tensor_names.attn_output_w_b.format(i)
+                 )
+                 .to(dtype)
+                 .T,
+             ),
+         )
+         adapters.append(LoRAEntry(attention=attention_lora))
+     return cls(adapters=adapters)
+
+   @classmethod
+   def from_flatbuffers(
+       cls,
+       flatbuffer_model: bytearray,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from FlatBuffers.
+
+     Args:
+       flatbuffer_model: FlatBuffers model.
+       dtype: Data type of the LoRA weights.
+
+     Returns:
+       LoRA weights for all modules.
+     """
+     model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
+     model = schema_fb.ModelT.InitFromObj(model)
+
+     flat_names = []
+     tensors = []
+     for tensor in model.subgraphs[0].tensors:
+       name = tensor.name.decode("utf-8")
+       assert name.startswith("lora_")
+       flat_names.append(name.split("lora_")[-1])
+       buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
+       arr = np.frombuffer(buffer_bytes, dtype=np.float32).reshape(tensor.shape)
+       torch_tensor = torch.from_numpy(arr).to(dtype)
+       tensors.append(torch_tensor)
+
+     return _unflatten_lora(tensors, (flat_names, []))
+
+   @classmethod
+   def zeros(
+       cls,
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights with zeros.
+
+     Args:
+       rank: Rank of the LoRA weights.
+       config: Model configuration.
+       dtype: Data type of the LoRA weights. Currently only float32 is supported.
+
+     Returns:
+       LoRA weights with zeros.
+     """
+     return cls._from_tensor_generator(
+         tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
+         rank=rank,
+         config=config,
+         dtype=dtype,
+     )
+
+   @classmethod
+   def random(
+       cls,
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights with random values.
+
+     Args:
+       rank: Rank of the LoRA weights.
+       config: Model configuration.
+       dtype: Data type of the LoRA weights.
+
+     Returns:
+       LoRA weights with random values.
+     """
+     return cls._from_tensor_generator(
+         tensor_generator=lambda shape, dtype: torch.randint(
+             low=0, high=128, size=shape, dtype=dtype
+         ),
+         rank=rank,
+         config=config,
+         dtype=dtype,
+     )
+
+   @classmethod
+   def _from_tensor_generator(
+       cls,
+       tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from a tensor generator."""
+     adapters = []
+
+     for i in range(config.num_layers):
+       block_config = config.block_config(i)
+       q_per_kv = (
+           block_config.attn_config.num_heads
+           // block_config.attn_config.num_query_groups
+       )
+       q_out_dim = q_per_kv * block_config.attn_config.head_dim
+       k_out_dim = v_out_dim = block_config.attn_config.head_dim
+       attention_lora = AttentionLoRA(
+           query=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, q_out_dim), dtype),
+           ),
+           key=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, k_out_dim), dtype),
+           ),
+           value=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, v_out_dim), dtype),
+           ),
+           output=LoRAWeight(
+               a_prime=tensor_generator(
+                   (
+                       block_config.attn_config.num_heads
+                       * block_config.attn_config.head_dim,
+                       rank,
+                   ),
+                   dtype,
+               ),
+               b_prime=tensor_generator((rank, config.embedding_dim), dtype),
+           ),
+       )
+       adapters.append(LoRAEntry(attention=attention_lora))
+     return cls(adapters=adapters)
+
+   def to_tflite(self) -> bytearray:
+     """Converts LoRA to FlatBuffers."""
+     return _lora_to_flatbuffers(self)
+
+
+ def apply_lora(
+     x: torch.Tensor,
+     lora_weight: LoRAWeight,
+     shape: Optional[Tuple[int, ...]] = None,
+ ) -> torch.Tensor:
+   """Applies LoRA weights to a tensor.
+
+   Args:
+     x: Input tensor.
+     lora_weight: LoRA weight.
+     shape: Output shape. If None, the output shape is the same as the input
+       shape.
+
+   Returns:
+     Output tensor.
+   """
+   output = torch.matmul(
+       torch.matmul(x, lora_weight.a_prime), lora_weight.b_prime
+   )
+   if shape is not None:
+     output = output.reshape(shape)
+   return output
+
+
+ def _flatten_attention_lora(
+     lora: AttentionLoRA, block_index: int
+ ) -> Tuple[List[torch.Tensor], List[str]]:
+   """Flattens LoRA weights for attention module."""
+   flattened = []
+   flat_names = []
+   flattened.append(lora.query.a_prime)
+   flat_names.append(f"atten_q_a_prime_weight_{block_index}")
+   flattened.append(lora.query.b_prime)
+   flat_names.append(f"atten_q_b_prime_weight_{block_index}")
+   flattened.append(lora.key.a_prime)
+   flat_names.append(f"atten_k_a_prime_weight_{block_index}")
+   flattened.append(lora.key.b_prime)
+   flat_names.append(f"atten_k_b_prime_weight_{block_index}")
+   flattened.append(lora.value.a_prime)
+   flat_names.append(f"atten_v_a_prime_weight_{block_index}")
+   flattened.append(lora.value.b_prime)
+   flat_names.append(f"atten_v_b_prime_weight_{block_index}")
+   flattened.append(lora.output.a_prime)
+   flat_names.append(f"atten_o_a_prime_weight_{block_index}")
+   flattened.append(lora.output.b_prime)
+   flat_names.append(f"atten_o_b_prime_weight_{block_index}")
+   return flattened, flat_names
+
+
+ def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
+   """Flattens LoRA weights."""
+   flattened = []
+   flat_names = []
+   none_names = []
+   for i, entry in enumerate(lora.adapters):
+     attn_flattened, attn_flat_names = _flatten_attention_lora(
+         lora=entry.attention, block_index=i
+     )
+     flattened.extend(attn_flattened)
+     flat_names.extend(attn_flat_names)
+   return flattened, [flat_names, none_names]
+
+
+ def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
+   """Flattens LoRA weights with keys."""
+   flattened, (flat_names, _) = _flatten_lora(lora)
+   return [
+       (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+   ], flat_names
+
+
+ def _unflatten_lora(
+     values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
+ ) -> LoRA:
+   """Unflattens LoRA object."""
+   flat_names, _ = context
+   names_weights = list(zip(flat_names, values))
+   adapters = {}
+   while names_weights:
+     name, weight = names_weights.pop(0)
+     block_idx = int(name.split("_")[-1])
+     if block_idx not in adapters:
+       adapters[block_idx] = LoRAEntry(
+           attention=AttentionLoRA(
+               query=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               key=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               value=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               output=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+           )
+       )
+
+     if name.startswith("atten_"):
+       if "q_a_prime" in name:
+         adapters[block_idx].attention.query.a_prime = weight
+       elif "q_b_prime" in name:
+         adapters[block_idx].attention.query.b_prime = weight
+       elif "k_a_prime" in name:
+         adapters[block_idx].attention.key.a_prime = weight
+       elif "k_b_prime" in name:
+         adapters[block_idx].attention.key.b_prime = weight
+       elif "v_a_prime" in name:
+         adapters[block_idx].attention.value.a_prime = weight
+       elif "v_b_prime" in name:
+         adapters[block_idx].attention.value.b_prime = weight
+       elif "o_a_prime" in name:
+         adapters[block_idx].attention.output.a_prime = weight
+       elif "o_b_prime" in name:
+         adapters[block_idx].attention.output.b_prime = weight
+       else:
+         raise ValueError(f"Unsupported name: {name}")
+     else:
+       raise ValueError(f"Unsupported name: {name}")
+
+   return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))
+
+
+ pytree.register_pytree_node(
+     LoRA,
+     _flatten_lora,
+     _unflatten_lora,
+     flatten_with_keys_fn=_flatten_lora_with_keys,
+     serialized_type_name="",
+ )
+
+
+ def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int:
+   """Adds a buffer to the FlatBuffers."""
+   if data is not None:
+     assert data.dtype == np.float32
+     schema_fb.BufferStartDataVector(builder, data.size * data.itemsize)
+     for value in reversed(data.flatten().tolist()):
+       builder.PrependFloat32(value)
+     data_offset = builder.EndVector()
+   else:
+     schema_fb.BufferStartDataVector(builder, 0)
+     data_offset = builder.EndVector()
+
+   schema_fb.BufferStart(builder)
+   schema_fb.BufferAddData(builder, data_offset)
+   buffer_offset = schema_fb.BufferEnd(builder)
+   return buffer_offset
+
+
+ def _add_tensor(
+     builder: flatbuffers.Builder,
+     name: str,
+     shape: Tuple[int, ...],
+     buffer_idx: int,
+ ) -> int:
+   """Adds a tensor to the FlatBuffers."""
+   name_offset = builder.CreateString(name)
+   schema_fb.TensorStartShapeVector(builder, len(shape))
+   for dim in reversed(shape):
+     builder.PrependInt32(dim)
+   shape_offset = builder.EndVector()
+   schema_fb.TensorStart(builder)
+   schema_fb.TensorAddName(builder, name_offset)
+   schema_fb.TensorAddShape(builder, shape_offset)
+   schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
+   schema_fb.TensorAddBuffer(builder, buffer_idx)
+   tensor_offset = schema_fb.TensorEnd(builder)
+   return tensor_offset
+
+
+ def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
+   """Converts LoRA to FlatBuffers."""
+   tensors, (names, _) = _flatten_lora(lora)
+   # Need to manually add the "lora_" prefix to the names here. The export will
+   # add the prefix automatically.
+   names = [f"lora_{name}" for name in names]
+   builder = flatbuffers.Builder(4096)
+
+   # Convention to add an empty buffer in the beginning.
+   buffer_offsets = [_add_buffer(builder, None)]
+   for tensor in tensors:
+     buffer_offsets.append(
+         _add_buffer(builder, tensor.detach().type(torch.float32).numpy())
+     )
+
+   schema_fb.ModelStartBuffersVector(builder, len(buffer_offsets))
+   for buffer_offset in reversed(buffer_offsets):
+     builder.PrependUOffsetTRelative(buffer_offset)
+   buffers_offset = builder.EndVector()
+
+   tensor_offsets = []
+   for i, (name, tensor) in enumerate(zip(names, tensors)):
+     # Note that the zeroth buffer is empty and reserved for the convention.
+     tensor_offsets.append(_add_tensor(builder, name, tensor.shape, i + 1))
+
+   schema_fb.SubGraphStartTensorsVector(builder, len(tensor_offsets))
+   for tensor_offset in reversed(tensor_offsets):
+     builder.PrependUOffsetTRelative(tensor_offset)
+   tensors_offset = builder.EndVector()
+
+   string_offset = builder.CreateString("lora_params")
+   schema_fb.SubGraphStart(builder)
+   schema_fb.SubGraphAddName(builder, string_offset)
+   schema_fb.SubGraphAddTensors(builder, tensors_offset)
+   subgraph_offset = schema_fb.SubGraphEnd(builder)
+
+   schema_fb.ModelStartSubgraphsVector(builder, 1)
+   builder.PrependUOffsetTRelative(subgraph_offset)
+   subgraphs_offset = builder.EndVector()
+
+   string_offset = builder.CreateString("lora_params")
+   schema_fb.ModelStart(builder)
+   schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
+   schema_fb.ModelAddDescription(builder, string_offset)
+   schema_fb.ModelAddBuffers(builder, buffers_offset)
+   schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
+   model_offset = schema_fb.ModelEnd(builder)
+   builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
+   flatbuffer_model = builder.Output()
+
+   return flatbuffer_model
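For reference, a minimal usage sketch of the container API defined above (and exercised by test_lora.py below): build a LoRA object, serialize it to a standalone TFLite flatbuffer, and recover it. It assumes `config` is an existing model_config.ModelConfig (for example, one produced by the model builders in ai_edge_torch/generative/examples); the output path is a placeholder.

from ai_edge_torch.generative.layers import lora as lora_utils

# `config` is assumed to be a model_config.ModelConfig for the target model.
lora = lora_utils.LoRA.random(rank=16, config=config)

# Serialize the adapter weights into a standalone flatbuffer...
flatbuffer_model = lora.to_tflite()
with open("/tmp/lora_params.tflite", "wb") as f:  # placeholder path
  f.write(flatbuffer_model)

# ...and rebuild an equivalent LoRA object from the raw bytes.
recovered = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
assert recovered.get_rank() == 16
assert recovered == lora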
@@ -0,0 +1,147 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """A suite of tests to validate LoRA utilities."""
+
+ from ai_edge_torch.generative.layers import lora as lora_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import torch
+ from absl.testing import absltest as googletest
+ from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import
+
+
+ class TestLora(googletest.TestCase):
+   """Tests for LoRA utilities."""
+
+   def test_safetensors_builder(self):
+     """Converts a safetensors file to a LoRA module."""
+
+     tensor_names = lora_utils.LoRATensorNames(
+         attn_query_w_a=(
+             "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
+         ),
+         attn_query_w_b=(
+             "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
+         ),
+         attn_key_w_a=(
+             "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
+         ),
+         attn_key_w_b=(
+             "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
+         ),
+         attn_value_w_a=(
+             "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
+         ),
+         attn_value_w_b=(
+             "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
+         ),
+         attn_output_w_a=(
+             "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
+         ),
+         attn_output_w_b=(
+             "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
+         ),
+     )
+
+     safetensors_file = resource_loader.get_path_to_datafile(
+         "fixtures/test_lora_rank16.safetensors"
+     )
+     config = self._get_test_config(
+         num_layers=1,
+         head_dim=8,
+         num_query_groups=1,
+         kv_cache_max_len=16,
+     )
+     lora = lora_utils.LoRA.from_safetensors(
+         safetensors_file,
+         scale=1.0,
+         lora_tensor_names=tensor_names,
+         config=config,
+     )
+     self.assertEqual(lora.get_rank(), 16)
+
+   def test_torch_export(self):
+     """Tests the export of the LoRA module."""
+
+     class TestModel(torch.nn.Module):
+
+       def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
+         x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
+         return x
+
+     n = 1
+     head_dim = 2
+     num_query_groups = 1
+     key_length = 4
+     config = self._get_test_config(
+         num_layers=n,
+         head_dim=head_dim,
+         num_query_groups=num_query_groups,
+         kv_cache_max_len=key_length,
+     )
+     inputs = torch.zeros((n, 1, head_dim))
+     lora = lora_utils.LoRA.zeros(rank=16, config=config)
+     model = TestModel()
+     exported_program = torch.export.export(model, (inputs, lora))
+     input_specs = exported_program.graph_signature.input_specs
+     # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
+     # 2 for output lora.
+     self.assertLen(input_specs, 9)
+     self.assertEqual(input_specs[0].arg.name, "x")
+     self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
+     self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
+     self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
+     self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
+     self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
+     self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
+     self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
+     self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")
+
+   def test_lora_tflite_serialization(self):
+     """Tests the serialization of the LoRA module."""
+     config = self._get_test_config(
+         num_layers=2,
+         head_dim=8,
+         num_query_groups=1,
+         kv_cache_max_len=16,
+     )
+     lora = lora_utils.LoRA.random(rank=16, config=config)
+     flatbuffer_model = lora.to_tflite()
+     recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
+     self.assertEqual(lora, recovered_lora)
+
+   def _get_test_config(
+       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+   ):
+     """Returns a test model config."""
+     attn_config = cfg.AttentionConfig(
+         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
+     )
+     block_config = cfg.TransformerBlockConfig(
+         attn_config=attn_config, ff_config=None
+     )
+     config = cfg.ModelConfig(
+         kv_cache_max_len=kv_cache_max_len,
+         embedding_dim=head_dim,
+         block_configs=block_config,
+         num_layers=num_layers,
+         max_seq_len=None,
+         vocab_size=None,
+     )
+     return config
+
+
+ if __name__ == "__main__":
+   googletest.main()
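To close the loop, a hedged sketch of loading real adapter weights (rather than the random fixtures used above) through LoRA.from_safetensors. The tensor-name templates mirror the PEFT-style names used in test_safetensors_builder; the file path, `config`, and the lora_alpha/rank values are placeholders to substitute for your own adapter.

from ai_edge_torch.generative.layers import lora as lora_utils

# "{}" in each template is filled with the layer index by from_safetensors.
tensor_names = lora_utils.LoRATensorNames(
    attn_query_w_a="base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight",
    attn_query_w_b="base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight",
    attn_key_w_a="base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight",
    attn_key_w_b="base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight",
    attn_value_w_a="base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight",
    attn_value_w_b="base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight",
    attn_output_w_a="base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight",
    attn_output_w_b="base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight",
)

lora = lora_utils.LoRA.from_safetensors(
    "adapter_model.safetensors",  # placeholder path to a PEFT adapter checkpoint
    scale=16 / 16,                # commonly lora_alpha / rank; adjust to your training setup
    config=config,                # assumed model_config.ModelConfig for the base model
    lora_tensor_names=tensor_names,
)
print(lora.get_rank())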