ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24):
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  4. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  5. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  6. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  7. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  9. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  10. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  11. ai_edge_torch/generative/layers/attention.py +41 -8
  12. ai_edge_torch/generative/layers/lora.py +557 -0
  13. ai_edge_torch/generative/test/test_lora.py +147 -0
  14. ai_edge_torch/generative/utilities/converter.py +100 -47
  15. ai_edge_torch/generative/utilities/model_builder.py +7 -2
  16. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  17. ai_edge_torch/odml_torch/export.py +6 -2
  18. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +24 -22
  22. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,557 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """LoRA weights for generative models.
17
+
18
+ The current implementation supports attention-only LoRA. Additionally, we expect
19
+ LoRA weights for all projections within the attention module (i.e., Q, K, V, O).
20
+ """
21
+
22
+ import dataclasses
23
+ from typing import Any, Callable, List, Optional, Tuple
24
+
25
+ from ai_edge_torch.generative.layers import model_config
26
+ import flatbuffers
27
+ import numpy as np
28
+ import safetensors
29
+ import torch
30
+ import torch.utils._pytree as pytree
31
+
32
+ from tensorflow.lite.python import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
33
+
34
# Written into the `version` field of the serialized LoRA model.
_TFLITE_SCHEMA_VERSION = 3
# FlatBuffers file identifier passed to `builder.Finish` for the LoRA model.
_TFLITE_FILE_IDENTIFIER = b"TFL3"
36
+
37
+
38
@dataclasses.dataclass
class LoRAWeight:
  """LoRA weight per projection. The weights are pre-transposed.

  As built in this module, `a_prime` has shape `(input_dim, rank)` and
  `b_prime` has shape `(rank, output_dim)`, so the LoRA delta is
  `x @ a_prime @ b_prime`.
  """

  a_prime: torch.Tensor
  b_prime: torch.Tensor

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Returns True if `other` holds approximately equal weights.

    Args:
      other: Object to compare against.
      rtol: Relative tolerance forwarded to `torch.allclose`.
      atol: Absolute tolerance forwarded to `torch.allclose`.

    Returns:
      False if `other` is not a `LoRAWeight` or any pair of tensors differs in
      shape, dtype or device; otherwise whether both pairs are allclose.
    """
    if not isinstance(other, LoRAWeight):
      return False
    pairs = ((self.a_prime, other.a_prime), (self.b_prime, other.b_prime))
    for mine, theirs in pairs:
      # torch.allclose raises (rather than returning False) on dtype or device
      # mismatches, so report those as "not equal" up front.
      if (
          mine.shape != theirs.shape
          or mine.dtype != theirs.dtype
          or mine.device != theirs.device
      ):
        return False
    return all(
        torch.allclose(mine, theirs, rtol=rtol, atol=atol)
        for mine, theirs in pairs
    )
55
+
56
+
57
@dataclasses.dataclass
class AttentionLoRA:
  """LoRA weights for the Q, K, V and output projections of one attention."""

  query: LoRAWeight
  key: LoRAWeight
  value: LoRAWeight
  output: LoRAWeight

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Compares all four projections pairwise with the given tolerances."""
    if not isinstance(other, AttentionLoRA):
      return False
    projection_pairs = (
        (self.query, other.query),
        (self.key, other.key),
        (self.value, other.value),
        (self.output, other.output),
    )
    # all() short-circuits on the first mismatch, like a chain of `and`s.
    return all(
        mine.__eq__(theirs, rtol=rtol, atol=atol)
        for mine, theirs in projection_pairs
    )
75
+
76
+
77
@dataclasses.dataclass
class LoRAEntry:
  """LoRA weights attached to a single layer (attention only for now)."""

  attention: AttentionLoRA

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Compares the attention LoRA weights with the given tolerances."""
    return isinstance(other, LoRAEntry) and self.attention.__eq__(
        other.attention, rtol=rtol, atol=atol
    )
87
+
88
+
89
@dataclasses.dataclass
class LoRATensorNames:
  """Tensor names for LoRA weights.

  Each field is a name template that `LoRA.from_safetensors` formats with the
  layer index via `str.format`, so it is expected to contain a `{}`
  placeholder. `*_w_a` names the tensor loaded into `a_prime` and `*_w_b` the
  one loaded into `b_prime` of the corresponding projection.
  """

  # Query projection.
  attn_query_w_a: str
  attn_query_w_b: str

  # Key projection.
  attn_key_w_a: str
  attn_key_w_b: str

  # Value projection.
  attn_value_w_a: str
  attn_value_w_b: str

  # Attention output projection.
  attn_output_w_a: str
  attn_output_w_b: str
104
+
105
+
106
@dataclasses.dataclass
class LoRA:
  """LoRA weights for all modules.

  Attributes:
    adapters: One `LoRAEntry` per layer, ordered by layer index.
  """

  adapters: Tuple[LoRAEntry, ...]

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Returns True if `other` is a LoRA with pairwise-allclose adapters."""
    if not isinstance(other, LoRA):
      return False
    if len(self.adapters) != len(other.adapters):
      return False
    return all(
        adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
        for adapter, other_adapter in zip(self.adapters, other.adapters)
    )

  def get_rank(self) -> int:
    """Returns the rank of the LoRA weights.

    The rank is read from the first layer's query `a_prime`, which is stored
    pre-transposed as `(input_dim, rank)`.
    """
    return self.adapters[0].attention.query.a_prime.shape[1]

  @classmethod
  def from_safetensors(
      cls,
      path: str,
      scale: float,
      config: model_config.ModelConfig,
      lora_tensor_names: LoRATensorNames,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a Hugging Face model.

    Args:
      path: Path to the safetensors file holding the LoRA weights.
      scale: Scale factor for the LoRA weights (applied only to one of the
        projections, `a_prime`). The scaling factor depends on the training
        configuration. The common values are either `lora_alpha / rank` or
        `lora_alpha / sqrt(rank)`.
      config: Model configuration; one adapter is loaded for each of its
        `num_layers` layers.
      lora_tensor_names: Tensor name templates, formatted with the layer
        index.
      dtype: Data type of the LoRA weights. Currently only float32 is supported.

    Returns:
      LoRA weights for all modules.
    """
    names = lora_tensor_names
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:

      def load(a_template: str, b_template: str, layer: int) -> LoRAWeight:
        # Stored weights are transposed so they can be right-multiplied;
        # `scale` is folded into `a_prime` only so the delta is scaled once.
        a_prime = f.get_tensor(a_template.format(layer)).to(dtype).T * scale
        b_prime = f.get_tensor(b_template.format(layer)).to(dtype).T
        return LoRAWeight(a_prime=a_prime, b_prime=b_prime)

      adapters = [
          LoRAEntry(
              attention=AttentionLoRA(
                  query=load(names.attn_query_w_a, names.attn_query_w_b, i),
                  key=load(names.attn_key_w_a, names.attn_key_w_b, i),
                  value=load(names.attn_value_w_a, names.attn_value_w_b, i),
                  output=load(names.attn_output_w_a, names.attn_output_w_b, i),
              )
          )
          for i in range(config.num_layers)
      ]
    # A tuple, as annotated on `adapters` and as produced by `_unflatten_lora`.
    return cls(adapters=tuple(adapters))

  @classmethod
  def from_flatbuffers(
      cls,
      flatbuffer_model: bytearray,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from FlatBuffers.

    Reads the layout written by `to_tflite`: every tensor of the first
    subgraph is a FLOAT32 LoRA weight named "lora_<flat name>".

    Args:
      flatbuffer_model: FlatBuffers model.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights for all modules.

    Raises:
      ValueError: If a tensor name does not start with "lora_".
    """
    model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
    model = schema_fb.ModelT.InitFromObj(model)

    prefix = "lora_"
    flat_names = []
    tensors = []
    for tensor in model.subgraphs[0].tensors:
      name = tensor.name.decode("utf-8")
      if not name.startswith(prefix):
        raise ValueError(
            f"Expected a tensor name starting with {prefix!r}, got {name!r}."
        )
      flat_names.append(name.removeprefix(prefix))
      buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
      # FlatBuffers data is little-endian. `astype` also copies out of the
      # read-only buffer so the torch tensor owns writable memory.
      arr = (
          np.frombuffer(buffer_bytes, dtype="<f4")
          .reshape(tensor.shape)
          .astype(np.float32)
      )
      tensors.append(torch.from_numpy(arr).to(dtype))

    return _unflatten_lora(tensors, (flat_names, []))

  @classmethod
  def zeros(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with zeros.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights. Currently only float32 is supported.

    Returns:
      LoRA weights with zeros.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def random(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with random integer values in [0, 128).

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights with random values.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.randint(
            low=0, high=128, size=shape, dtype=dtype
        ),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def _from_tensor_generator(
      cls,
      tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights whose tensors come from `tensor_generator`.

    Args:
      tensor_generator: Called as `tensor_generator(shape, dtype)` for every
        weight, in Q, K, V, O order (`a_prime` before `b_prime`) per layer.
      rank: Rank of the LoRA weights.
      config: Model configuration providing per-layer attention sizes.
      dtype: Data type passed to `tensor_generator`.

    Returns:
      LoRA weights for all layers of `config`.
    """

    def make_weight(in_dim: int, out_dim: int) -> LoRAWeight:
      return LoRAWeight(
          a_prime=tensor_generator((in_dim, rank), dtype),
          b_prime=tensor_generator((rank, out_dim), dtype),
      )

    adapters = []
    for i in range(config.num_layers):
      attn_config = config.block_config(i).attn_config
      # The query delta must cover every attention head. Its width equals the
      # input width of the output projection below. Key/value cover one head
      # per query group (MHA: groups == heads, MQA: groups == 1, GQA: in
      # between). Per-group widths (`q_per_kv * head_dim` for Q, `head_dim`
      # for K/V) are only correct when `num_query_groups == 1`.
      q_out_dim = attn_config.num_heads * attn_config.head_dim
      kv_out_dim = attn_config.num_query_groups * attn_config.head_dim
      attn_out_in_dim = attn_config.num_heads * attn_config.head_dim
      attention_lora = AttentionLoRA(
          query=make_weight(config.embedding_dim, q_out_dim),
          key=make_weight(config.embedding_dim, kv_out_dim),
          value=make_weight(config.embedding_dim, kv_out_dim),
          output=make_weight(attn_out_in_dim, config.embedding_dim),
      )
      adapters.append(LoRAEntry(attention=attention_lora))
    return cls(adapters=tuple(adapters))

  def to_tflite(self) -> bytearray:
    """Converts LoRA to FlatBuffers (a TFLite model holding only the weights)."""
    return _lora_to_flatbuffers(self)
329
+
330
+
331
def apply_lora(
    x: torch.Tensor,
    lora_weight: LoRAWeight,
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Computes the LoRA delta `(x @ a_prime) @ b_prime` for one projection.

  Only the low-rank update is computed; it is not added to anything here.

  Args:
    x: Input tensor. Its last dimension must match the first dimension of
      `lora_weight.a_prime`.
    lora_weight: Pre-transposed LoRA weights of the projection.
    shape: Optional shape to reshape the result to. If None, the result keeps
      the shape produced by the matmuls. For 2-D weights this is
      `x.shape[:-1]` followed by the last dimension of `lora_weight.b_prime`,
      which is not necessarily the input shape.

  Returns:
    Output tensor holding the LoRA delta.
  """
  # Multiplying by the rank-sized factor first keeps the intermediate small.
  low_rank = torch.matmul(x, lora_weight.a_prime)
  output = torch.matmul(low_rank, lora_weight.b_prime)
  if shape is not None:
    output = output.reshape(shape)
  return output
353
+
354
+
355
def _flatten_attention_lora(
    lora: AttentionLoRA, block_index: int
) -> Tuple[List[torch.Tensor], List[str]]:
  """Lists one layer's eight attention LoRA tensors together with their names.

  The order is query, key, value, output, with `a_prime` before `b_prime`
  within each. Every name ends in `_<block_index>`.

  Returns:
    Two parallel lists: `(tensors, names)`.
  """
  named_tensors = (
      (f"atten_q_a_prime_weight_{block_index}", lora.query.a_prime),
      (f"atten_q_b_prime_weight_{block_index}", lora.query.b_prime),
      (f"atten_k_a_prime_weight_{block_index}", lora.key.a_prime),
      (f"atten_k_b_prime_weight_{block_index}", lora.key.b_prime),
      (f"atten_v_a_prime_weight_{block_index}", lora.value.a_prime),
      (f"atten_v_b_prime_weight_{block_index}", lora.value.b_prime),
      (f"atten_o_a_prime_weight_{block_index}", lora.output.a_prime),
      (f"atten_o_b_prime_weight_{block_index}", lora.output.b_prime),
  )
  names = [name for name, _ in named_tensors]
  tensors = [tensor for _, tensor in named_tensors]
  return tensors, names
378
+
379
+
380
def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
  """Flattens LoRA weights into tensors plus a pytree context.

  Returns:
    `(tensors, [names, []])`. `names` is parallel to `tensors`; the second
    context list is always empty.
  """
  tensors: List[torch.Tensor] = []
  names: List[str] = []
  for block_index, entry in enumerate(lora.adapters):
    block_tensors, block_names = _flatten_attention_lora(
        lora=entry.attention, block_index=block_index
    )
    tensors += block_tensors
    names += block_names
  return tensors, [names, []]
392
+
393
+
394
def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
  """Flattens LoRA weights, pairing each tensor with a `MappingKey` name.

  Returns:
    `(key_value_pairs, names)`, where each pair is `(MappingKey(name), tensor)`.
  """
  tensors, context = _flatten_lora(lora)
  names = context[0]
  keyed = [
      (pytree.MappingKey(name), tensor) for name, tensor in zip(names, tensors)
  ]
  return keyed, names
400
+
401
+
402
def _unflatten_lora(
    values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
) -> LoRA:
  """Unflattens LoRA object.

  This is the inverse of `_flatten_lora`: each flat name ends in
  `_<block_index>` and identifies which attention weight its tensor fills.

  Args:
    values: Flat tensors, parallel to the names in `context`.
    context: `(flat_names, _)`; only `flat_names` is used.

  Returns:
    A `LoRA` whose adapters are ordered by ascending block index.

  Raises:
    ValueError: If a name's suffix is not an integer, or the name does not
      identify a known attention LoRA weight.
  """
  # (substring, AttentionLoRA field, LoRAWeight field), checked in order.
  targets = (
      ("q_a_prime", "query", "a_prime"),
      ("q_b_prime", "query", "b_prime"),
      ("k_a_prime", "key", "a_prime"),
      ("k_b_prime", "key", "b_prime"),
      ("v_a_prime", "value", "a_prime"),
      ("v_b_prime", "value", "b_prime"),
      ("o_a_prime", "output", "a_prime"),
      ("o_b_prime", "output", "b_prime"),
  )
  flat_names, _ = context
  adapters = {}
  # Iterate directly; repeatedly popping the head of a list was quadratic.
  for name, weight in zip(flat_names, values):
    block_idx = int(name.split("_")[-1])
    if block_idx not in adapters:
      # Placeholders, expected to be filled in by the matching names.
      adapters[block_idx] = LoRAEntry(
          attention=AttentionLoRA(
              query=LoRAWeight(a_prime=None, b_prime=None),
              key=LoRAWeight(a_prime=None, b_prime=None),
              value=LoRAWeight(a_prime=None, b_prime=None),
              output=LoRAWeight(a_prime=None, b_prime=None),
          )
      )

    if not name.startswith("atten_"):
      raise ValueError(f"Unsupported name: {name}")
    for marker, projection, attr in targets:
      if marker in name:
        setattr(getattr(adapters[block_idx].attention, projection), attr, weight)
        break
    else:
      raise ValueError(f"Unsupported name: {name}")

  return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))
457
+
458
+
459
# Register LoRA as a pytree node so a LoRA object can be passed as a model
# input (e.g. to torch.export.export). It is flattened by `_flatten_lora`,
# keyed by name via `_flatten_lora_with_keys`, and rebuilt by
# `_unflatten_lora`.
pytree.register_pytree_node(
    LoRA,
    _flatten_lora,
    _unflatten_lora,
    flatten_with_keys_fn=_flatten_lora_with_keys,
    serialized_type_name="",
)
466
+
467
+
468
def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int:
  """Adds a buffer to the FlatBuffers.

  Args:
    builder: FlatBuffers builder receiving the `Buffer` table.
    data: float32 array whose raw bytes become the buffer contents, or None
      for an empty buffer.

  Returns:
    Offset of the finished `Buffer` table.

  Raises:
    ValueError: If `data` is not float32.
  """
  if data is None:
    schema_fb.BufferStartDataVector(builder, 0)
    data_offset = builder.EndVector()
  else:
    if data.dtype != np.float32:
      raise ValueError(f"Expected float32 buffer data, got {data.dtype}.")
    # Write all bytes with one call instead of prepending each float from
    # Python, which is very slow for real-sized LoRA weights. Force
    # little-endian (FlatBuffers' byte order) and C order to keep the layout.
    raw = data.astype("<f4", copy=False).tobytes(order="C")
    data_offset = builder.CreateByteVector(raw)

  schema_fb.BufferStart(builder)
  schema_fb.BufferAddData(builder, data_offset)
  return schema_fb.BufferEnd(builder)
484
+
485
+
486
def _add_tensor(
    builder: flatbuffers.Builder,
    name: str,
    shape: Tuple[int, ...],
    buffer_idx: int,
) -> int:
  """Serializes one FLOAT32 `Tensor` table into `builder`.

  Args:
    builder: FlatBuffers builder receiving the table.
    name: Tensor name.
    shape: Tensor dimensions, outermost first.
    buffer_idx: Index of the model buffer holding the tensor data.

  Returns:
    Offset of the finished `Tensor` table.
  """
  name_ref = builder.CreateString(name)

  # FlatBuffers vectors are built back to front, hence the reversed dims.
  schema_fb.TensorStartShapeVector(builder, len(shape))
  for dim in shape[::-1]:
    builder.PrependInt32(dim)
  shape_ref = builder.EndVector()

  schema_fb.TensorStart(builder)
  schema_fb.TensorAddName(builder, name_ref)
  schema_fb.TensorAddShape(builder, shape_ref)
  schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
  schema_fb.TensorAddBuffer(builder, buffer_idx)
  return schema_fb.TensorEnd(builder)
505
+
506
+
507
def _create_offset_vector(
    builder: flatbuffers.Builder,
    start_vector: Callable[[flatbuffers.Builder, int], Any],
    offsets: List[int],
) -> int:
  """Writes `offsets` (in order) as a vector of table offsets."""
  start_vector(builder, len(offsets))
  # FlatBuffers vectors are built back to front.
  for offset in reversed(offsets):
    builder.PrependUOffsetTRelative(offset)
  return builder.EndVector()


def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
  """Converts LoRA to FlatBuffers.

  The model has a single subgraph named "lora_params" whose tensors are the
  flattened LoRA weights (names prefixed with "lora_"). Each tensor has its
  own FLOAT32 buffer; buffer 0 is left empty by convention.

  Args:
    lora: LoRA weights to serialize.

  Returns:
    The serialized FlatBuffers model.
  """
  tensors, (names, _) = _flatten_lora(lora)
  # Need to manually add the "lora_" prefix to the names here. The export will
  # add the prefix automatically.
  names = [f"lora_{name}" for name in names]
  builder = flatbuffers.Builder(4096)

  # Convention to add an empty buffer in the beginning.
  buffer_offsets = [_add_buffer(builder, None)]
  for tensor in tensors:
    # `.numpy()` only works on CPU tensors, so move the data there first.
    data = tensor.detach().to(device="cpu", dtype=torch.float32).numpy()
    buffer_offsets.append(_add_buffer(builder, data))
  buffers_offset = _create_offset_vector(
      builder, schema_fb.ModelStartBuffersVector, buffer_offsets
  )

  # Tensor i uses buffer i + 1 because buffer 0 is the empty placeholder.
  tensor_offsets = [
      _add_tensor(builder, name, tuple(tensor.shape), i + 1)
      for i, (name, tensor) in enumerate(zip(names, tensors))
  ]
  tensors_offset = _create_offset_vector(
      builder, schema_fb.SubGraphStartTensorsVector, tensor_offsets
  )

  string_offset = builder.CreateString("lora_params")
  schema_fb.SubGraphStart(builder)
  schema_fb.SubGraphAddName(builder, string_offset)
  schema_fb.SubGraphAddTensors(builder, tensors_offset)
  subgraph_offset = schema_fb.SubGraphEnd(builder)

  subgraphs_offset = _create_offset_vector(
      builder, schema_fb.ModelStartSubgraphsVector, [subgraph_offset]
  )

  string_offset = builder.CreateString("lora_params")
  schema_fb.ModelStart(builder)
  schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
  schema_fb.ModelAddDescription(builder, string_offset)
  schema_fb.ModelAddBuffers(builder, buffers_offset)
  schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
  model_offset = schema_fb.ModelEnd(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return builder.Output()
@@ -0,0 +1,147 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """A suite of tests to validate LoRA utilities."""
17
+
18
+ from ai_edge_torch.generative.layers import lora as lora_utils
19
+ import ai_edge_torch.generative.layers.model_config as cfg
20
+ import torch
21
+ from absl.testing import absltest as googletest
22
+ from tensorflow.python.platform import resource_loader # pylint: disable=g-direct-tensorflow-import
23
+
24
+
25
class TestLora(googletest.TestCase):
  """Tests for LoRA utilities."""

  def test_safetensors_builder(self):
    """Loads a rank-16 LoRA from a safetensors fixture."""
    names = lora_utils.LoRATensorNames(
        attn_query_w_a=(
            "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
        ),
        attn_query_w_b=(
            "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
        ),
        attn_key_w_a=(
            "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
        ),
        attn_key_w_b=(
            "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
        ),
        attn_value_w_a=(
            "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
        ),
        attn_value_w_b=(
            "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
        ),
        attn_output_w_a=(
            "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
        ),
        attn_output_w_b=(
            "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
        ),
    )
    fixture_path = resource_loader.get_path_to_datafile(
        "fixtures/test_lora_rank16.safetensors"
    )
    test_config = self._get_test_config(
        num_layers=1, head_dim=8, num_query_groups=1, kv_cache_max_len=16
    )
    loaded = lora_utils.LoRA.from_safetensors(
        fixture_path,
        scale=1.0,
        lora_tensor_names=names,
        config=test_config,
    )
    self.assertEqual(loaded.get_rank(), 16)

  def test_torch_export(self):
    """Checks that each LoRA tensor becomes a named exported input."""

    class TestModel(torch.nn.Module):

      def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
        x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
        return x

    layers = 1
    dim = 2
    test_config = self._get_test_config(
        num_layers=layers,
        head_dim=dim,
        num_query_groups=1,
        kv_cache_max_len=4,
    )
    exported = torch.export.export(
        TestModel(),
        (
            torch.zeros((layers, 1, dim)),
            lora_utils.LoRA.zeros(rank=16, config=test_config),
        ),
    )
    specs = exported.graph_signature.input_specs
    expected_names = (
        "x",
        "lora_atten_q_a_prime_weight_0",
        "lora_atten_q_b_prime_weight_0",
        "lora_atten_k_a_prime_weight_0",
        "lora_atten_k_b_prime_weight_0",
        "lora_atten_v_a_prime_weight_0",
        "lora_atten_v_b_prime_weight_0",
        "lora_atten_o_a_prime_weight_0",
        "lora_atten_o_b_prime_weight_0",
    )
    # 9 inputs: x plus an (a_prime, b_prime) pair for each of Q, K, V and O.
    self.assertLen(specs, 9)
    for spec, expected_name in zip(specs, expected_names):
      self.assertEqual(spec.arg.name, expected_name)

  def test_lora_tflite_serialization(self):
    """Round-trips random LoRA weights through the TFLite format."""
    test_config = self._get_test_config(
        num_layers=2, head_dim=8, num_query_groups=1, kv_cache_max_len=16
    )
    original = lora_utils.LoRA.random(rank=16, config=test_config)
    restored = lora_utils.LoRA.from_flatbuffers(original.to_tflite())
    self.assertEqual(original, restored)

  def _get_test_config(
      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
  ):
    """Builds a small single-head model config whose width equals head_dim."""
    block = cfg.TransformerBlockConfig(
        attn_config=cfg.AttentionConfig(
            num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
        ),
        ff_config=None,
    )
    return cfg.ModelConfig(
        kv_cache_max_len=kv_cache_max_len,
        embedding_dim=head_dim,
        block_configs=block,
        num_layers=num_layers,
        max_seq_len=None,
        vocab_size=None,
    )
144
+
145
+
146
if __name__ == "__main__":
  # Run this module's test cases through absltest (imported as googletest).
  googletest.main()