ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  6. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  7. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  8. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  9. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  10. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  13. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  14. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  15. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  17. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  18. ai_edge_torch/generative/layers/attention.py +45 -37
  19. ai_edge_torch/generative/layers/lora.py +557 -0
  20. ai_edge_torch/generative/layers/model_config.py +6 -2
  21. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  22. ai_edge_torch/generative/test/test_lora.py +147 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  24. ai_edge_torch/generative/utilities/converter.py +100 -47
  25. ai_edge_torch/generative/utilities/model_builder.py +23 -14
  26. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  27. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  28. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  29. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  30. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  31. ai_edge_torch/odml_torch/export.py +6 -2
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  33. ai_edge_torch/version.py +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
  36. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  37. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  38. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/lora.py
@@ -0,0 +1,557 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """LoRA weights for generative models.
+
+ The current implementation supports attention-only LoRA. Additionally, we expect
+ LoRA weights for all projections within the attention module (i.e., Q, K, V, O).
+ """
+
+ import dataclasses
+ from typing import Any, Callable, List, Optional, Tuple
+
+ from ai_edge_torch.generative.layers import model_config
+ import flatbuffers
+ import numpy as np
+ import safetensors
+ import torch
+ import torch.utils._pytree as pytree
+
+ from tensorflow.lite.python import schema_py_generated as schema_fb  # pylint: disable=g-direct-tensorflow-import
+
+ _TFLITE_SCHEMA_VERSION = 3
+ _TFLITE_FILE_IDENTIFIER = b"TFL3"
+
+
+ @dataclasses.dataclass
+ class LoRAWeight:
+   """LoRA weight per projection. The weights are pre-transposed."""
+
+   a_prime: torch.Tensor
+   b_prime: torch.Tensor
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRAWeight):
+       return False
+     if self.a_prime.shape != other.a_prime.shape:
+       return False
+     if self.b_prime.shape != other.b_prime.shape:
+       return False
+     return torch.allclose(
+         self.a_prime, other.a_prime, rtol=rtol, atol=atol
+     ) and torch.allclose(self.b_prime, other.b_prime, rtol=rtol, atol=atol)
+
+
+ @dataclasses.dataclass
+ class AttentionLoRA:
+   """LoRA weights for attention module."""
+
+   query: LoRAWeight
+   key: LoRAWeight
+   value: LoRAWeight
+   output: LoRAWeight
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, AttentionLoRA):
+       return False
+     return (
+         self.query.__eq__(other.query, rtol=rtol, atol=atol)
+         and self.key.__eq__(other.key, rtol=rtol, atol=atol)
+         and self.value.__eq__(other.value, rtol=rtol, atol=atol)
+         and self.output.__eq__(other.output, rtol=rtol, atol=atol)
+     )
+
+
+ @dataclasses.dataclass
+ class LoRAEntry:
+   """LoRA weights for a single layer."""
+
+   attention: AttentionLoRA
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRAEntry):
+       return False
+     return self.attention.__eq__(other.attention, rtol=rtol, atol=atol)
+
+
+ @dataclasses.dataclass
+ class LoRATensorNames:
+   """Tensor names for LoRA weights."""
+
+   attn_query_w_a: str
+   attn_query_w_b: str
+
+   attn_key_w_a: str
+   attn_key_w_b: str
+
+   attn_value_w_a: str
+   attn_value_w_b: str
+
+   attn_output_w_a: str
+   attn_output_w_b: str
+
+
+ @dataclasses.dataclass
+ class LoRA:
+   """LoRA weights for all modules."""
+
+   adapters: Tuple[LoRAEntry, ...]
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRA):
+       return False
+     if len(self.adapters) != len(other.adapters):
+       return False
+     return all(
+         adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
+         for adapter, other_adapter in zip(self.adapters, other.adapters)
+     )
+
+   def get_rank(self) -> int:
+     """Returns the rank of the LoRA weights."""
+     return self.adapters[0].attention.query.a_prime.shape[1]
+
+   @classmethod
+   def from_safetensors(
+       cls,
+       path: str,
+       scale: float,
+       config: model_config.ModelConfig,
+       lora_tensor_names: LoRATensorNames,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from a Hugging Face model.
+
+     Args:
+       path: Path to the model.
+       scale: Scale factor for the LoRA weights (applied only to one of the
+         projections). The scaling factor depends on the training configuration.
+         The common values are either `lora_alpha / rank` or `lora_alpha /
+         sqrt(rank)`.
+       config: Model configuration.
+       lora_tensor_names: Tensor names for the LoRA weights.
+       dtype: Data type of the LoRA weights. Currently only float32 is supported.
+
+     Returns:
+       LoRA weights for all modules.
+     """
+     with safetensors.safe_open(path, framework="pt", device="cpu") as f:
+       adapters = []
+       for i in range(config.num_layers):
+         attention_lora = AttentionLoRA(
+             query=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_query_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_query_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             key=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_key_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_key_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             value=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_value_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_value_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             output=LoRAWeight(
+                 a_prime=f.get_tensor(
+                     lora_tensor_names.attn_output_w_a.format(i)
+                 )
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(
+                     lora_tensor_names.attn_output_w_b.format(i)
+                 )
+                 .to(dtype)
+                 .T,
+             ),
+         )
+         adapters.append(LoRAEntry(attention=attention_lora))
+     return cls(adapters=adapters)
+
+   @classmethod
+   def from_flatbuffers(
+       cls,
+       flatbuffer_model: bytearray,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from FlatBuffers.
+
+     Args:
+       flatbuffer_model: FlatBuffers model.
+       dtype: Data type of the LoRA weights.
+
+     Returns:
+       LoRA weights for all modules.
+     """
+     model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
+     model = schema_fb.ModelT.InitFromObj(model)
+
+     flat_names = []
+     tensors = []
+     for tensor in model.subgraphs[0].tensors:
+       name = tensor.name.decode("utf-8")
+       assert name.startswith("lora_")
+       flat_names.append(name.split("lora_")[-1])
+       buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
+       arr = np.frombuffer(buffer_bytes, dtype=np.float32).reshape(tensor.shape)
+       torch_tensor = torch.from_numpy(arr).to(dtype)
+       tensors.append(torch_tensor)
+
+     return _unflatten_lora(tensors, (flat_names, []))
+
+   @classmethod
+   def zeros(
+       cls,
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights with zeros.
+
+     Args:
+       rank: Rank of the LoRA weights.
+       config: Model configuration.
+       dtype: Data type of the LoRA weights. Currently only float32 is supported.
+
+     Returns:
+       LoRA weights with zeros.
+     """
+     return cls._from_tensor_generator(
+         tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
+         rank=rank,
+         config=config,
+         dtype=dtype,
+     )
+
+   @classmethod
+   def random(
+       cls,
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights with random values.
+
+     Args:
+       rank: Rank of the LoRA weights.
+       config: Model configuration.
+       dtype: Data type of the LoRA weights.
+
+     Returns:
+       LoRA weights with random values.
+     """
+     return cls._from_tensor_generator(
+         tensor_generator=lambda shape, dtype: torch.randint(
+             low=0, high=128, size=shape, dtype=dtype
+         ),
+         rank=rank,
+         config=config,
+         dtype=dtype,
+     )
+
+   @classmethod
+   def _from_tensor_generator(
+       cls,
+       tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from a tensor generator."""
+     adapters = []
+
+     for i in range(config.num_layers):
+       block_config = config.block_config(i)
+       q_per_kv = (
+           block_config.attn_config.num_heads
+           // block_config.attn_config.num_query_groups
+       )
+       q_out_dim = q_per_kv * block_config.attn_config.head_dim
+       k_out_dim = v_out_dim = block_config.attn_config.head_dim
+       attention_lora = AttentionLoRA(
+           query=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, q_out_dim), dtype),
+           ),
+           key=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, k_out_dim), dtype),
+           ),
+           value=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, v_out_dim), dtype),
+           ),
+           output=LoRAWeight(
+               a_prime=tensor_generator(
+                   (
+                       block_config.attn_config.num_heads
+                       * block_config.attn_config.head_dim,
+                       rank,
+                   ),
+                   dtype,
+               ),
+               b_prime=tensor_generator((rank, config.embedding_dim), dtype),
+           ),
+       )
+       adapters.append(LoRAEntry(attention=attention_lora))
+     return cls(adapters=adapters)
+
+   def to_tflite(self) -> bytearray:
+     """Converts LoRA to FlatBuffers."""
+     return _lora_to_flatbuffers(self)
+
+
+ def apply_lora(
+     x: torch.Tensor,
+     lora_weight: LoRAWeight,
+     shape: Optional[Tuple[int, ...]] = None,
+ ) -> torch.Tensor:
+   """Applies LoRA weights to a tensor.
+
+   Args:
+     x: Input tensor.
+     lora_weight: LoRA weight.
+     shape: Output shape. If None, the output shape is the same as the input
+       shape.
+
+   Returns:
+     Output tensor.
+   """
+   output = torch.matmul(
+       torch.matmul(x, lora_weight.a_prime), lora_weight.b_prime
+   )
+   if shape is not None:
+     output = output.reshape(shape)
+   return output
+
+
+ def _flatten_attention_lora(
+     lora: AttentionLoRA, block_index: int
+ ) -> Tuple[List[torch.Tensor], List[str]]:
+   """Flattens LoRA weights for attention module."""
+   flattened = []
+   flat_names = []
+   flattened.append(lora.query.a_prime)
+   flat_names.append(f"atten_q_a_prime_weight_{block_index}")
+   flattened.append(lora.query.b_prime)
+   flat_names.append(f"atten_q_b_prime_weight_{block_index}")
+   flattened.append(lora.key.a_prime)
+   flat_names.append(f"atten_k_a_prime_weight_{block_index}")
+   flattened.append(lora.key.b_prime)
+   flat_names.append(f"atten_k_b_prime_weight_{block_index}")
+   flattened.append(lora.value.a_prime)
+   flat_names.append(f"atten_v_a_prime_weight_{block_index}")
+   flattened.append(lora.value.b_prime)
+   flat_names.append(f"atten_v_b_prime_weight_{block_index}")
+   flattened.append(lora.output.a_prime)
+   flat_names.append(f"atten_o_a_prime_weight_{block_index}")
+   flattened.append(lora.output.b_prime)
+   flat_names.append(f"atten_o_b_prime_weight_{block_index}")
+   return flattened, flat_names
+
+
+ def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
+   """Flattens LoRA weights."""
+   flattened = []
+   flat_names = []
+   none_names = []
+   for i, entry in enumerate(lora.adapters):
+     attn_flattened, attn_flat_names = _flatten_attention_lora(
+         lora=entry.attention, block_index=i
+     )
+     flattened.extend(attn_flattened)
+     flat_names.extend(attn_flat_names)
+   return flattened, [flat_names, none_names]
+
+
+ def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
+   """Flattens LoRA weights with keys."""
+   flattened, (flat_names, _) = _flatten_lora(lora)
+   return [
+       (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+   ], flat_names
+
+
+ def _unflatten_lora(
+     values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
+ ) -> LoRA:
+   """Unflattens LoRA object."""
+   flat_names, _ = context
+   names_weights = list(zip(flat_names, values))
+   adapters = {}
+   while names_weights:
+     name, weight = names_weights.pop(0)
+     block_idx = int(name.split("_")[-1])
+     if block_idx not in adapters:
+       adapters[block_idx] = LoRAEntry(
+           attention=AttentionLoRA(
+               query=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               key=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               value=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               output=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+           )
+       )
+
+     if name.startswith("atten_"):
+       if "q_a_prime" in name:
+         adapters[block_idx].attention.query.a_prime = weight
+       elif "q_b_prime" in name:
+         adapters[block_idx].attention.query.b_prime = weight
+       elif "k_a_prime" in name:
+         adapters[block_idx].attention.key.a_prime = weight
+       elif "k_b_prime" in name:
+         adapters[block_idx].attention.key.b_prime = weight
+       elif "v_a_prime" in name:
+         adapters[block_idx].attention.value.a_prime = weight
+       elif "v_b_prime" in name:
+         adapters[block_idx].attention.value.b_prime = weight
+       elif "o_a_prime" in name:
+         adapters[block_idx].attention.output.a_prime = weight
+       elif "o_b_prime" in name:
+         adapters[block_idx].attention.output.b_prime = weight
+       else:
+         raise ValueError(f"Unsupported name: {name}")
+     else:
+       raise ValueError(f"Unsupported name: {name}")
+
+   return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))
+
+
+ pytree.register_pytree_node(
+     LoRA,
+     _flatten_lora,
+     _unflatten_lora,
+     flatten_with_keys_fn=_flatten_lora_with_keys,
+     serialized_type_name="",
+ )
+
+
+ def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int:
+   """Adds a buffer to the FlatBuffers."""
+   if data is not None:
+     assert data.dtype == np.float32
+     schema_fb.BufferStartDataVector(builder, data.size * data.itemsize)
+     for value in reversed(data.flatten().tolist()):
+       builder.PrependFloat32(value)
+     data_offset = builder.EndVector()
+   else:
+     schema_fb.BufferStartDataVector(builder, 0)
+     data_offset = builder.EndVector()
+
+   schema_fb.BufferStart(builder)
+   schema_fb.BufferAddData(builder, data_offset)
+   buffer_offset = schema_fb.BufferEnd(builder)
+   return buffer_offset
+
+
+ def _add_tensor(
+     builder: flatbuffers.Builder,
+     name: str,
+     shape: Tuple[int, ...],
+     buffer_idx: int,
+ ) -> int:
+   """Adds a tensor to the FlatBuffers."""
+   name_offset = builder.CreateString(name)
+   schema_fb.TensorStartShapeVector(builder, len(shape))
+   for dim in reversed(shape):
+     builder.PrependInt32(dim)
+   shape_offset = builder.EndVector()
+   schema_fb.TensorStart(builder)
+   schema_fb.TensorAddName(builder, name_offset)
+   schema_fb.TensorAddShape(builder, shape_offset)
+   schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
+   schema_fb.TensorAddBuffer(builder, buffer_idx)
+   tensor_offset = schema_fb.TensorEnd(builder)
+   return tensor_offset
+
+
+ def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
+   """Converts LoRA to FlatBuffers."""
+   tensors, (names, _) = _flatten_lora(lora)
+   # Need to manually add the "lora_" prefix to the names here. The export will
+   # add the prefix automatically.
+   names = [f"lora_{name}" for name in names]
+   builder = flatbuffers.Builder(4096)
+
+   # Convention to add an empty buffer in the beginning.
+   buffer_offsets = [_add_buffer(builder, None)]
+   for tensor in tensors:
+     buffer_offsets.append(
+         _add_buffer(builder, tensor.detach().type(torch.float32).numpy())
+     )
+
+   schema_fb.ModelStartBuffersVector(builder, len(buffer_offsets))
+   for buffer_offset in reversed(buffer_offsets):
+     builder.PrependUOffsetTRelative(buffer_offset)
+   buffers_offset = builder.EndVector()
+
+   tensor_offsets = []
+   for i, (name, tensor) in enumerate(zip(names, tensors)):
+     # Note that the zeroth buffer is empty and reserved for the convention.
+     tensor_offsets.append(_add_tensor(builder, name, tensor.shape, i + 1))
+
+   schema_fb.SubGraphStartTensorsVector(builder, len(tensor_offsets))
+   for tensor_offset in reversed(tensor_offsets):
+     builder.PrependUOffsetTRelative(tensor_offset)
+   tensors_offset = builder.EndVector()
+
+   string_offset = builder.CreateString("lora_params")
+   schema_fb.SubGraphStart(builder)
+   schema_fb.SubGraphAddName(builder, string_offset)
+   schema_fb.SubGraphAddTensors(builder, tensors_offset)
+   subgraph_offset = schema_fb.SubGraphEnd(builder)
+
+   schema_fb.ModelStartSubgraphsVector(builder, 1)
+   builder.PrependUOffsetTRelative(subgraph_offset)
+   subgraphs_offset = builder.EndVector()
+
+   string_offset = builder.CreateString("lora_params")
+   schema_fb.ModelStart(builder)
+   schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
+   schema_fb.ModelAddDescription(builder, string_offset)
+   schema_fb.ModelAddBuffers(builder, buffers_offset)
+   schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
+   model_offset = schema_fb.ModelEnd(builder)
+   builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
+   flatbuffer_model = builder.Output()
+
+   return flatbuffer_model
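
Taken together, the new lora.py module covers the whole adapter path: load PEFT-exported weights with LoRA.from_safetensors (or build placeholders with LoRA.zeros / LoRA.random), then serialize them as a standalone TFLite flatbuffer with to_tflite. A minimal usage sketch follows; the checkpoint path and the PEFT-style tensor-name patterns are hypothetical placeholders, the lora_alpha/rank values are illustrative, and the TinyLlama example config is used only to make the calls concrete.

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import lora as lora_utils

# Any model_config.ModelConfig works; the TinyLlama example helper is used here
# only so the sketch has a concrete base-model config.
config = tiny_llama.get_model_config()

# Hypothetical PEFT-style tensor names; "{}" is filled with the layer index.
lora_tensor_names = lora_utils.LoRATensorNames(
    attn_query_w_a="base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight",
    attn_query_w_b="base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight",
    attn_key_w_a="base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight",
    attn_key_w_b="base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight",
    attn_value_w_a="base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight",
    attn_value_w_b="base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight",
    attn_output_w_a="base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight",
    attn_output_w_b="base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight",
)

lora_alpha, rank = 16.0, 8  # illustrative values; match your training setup
lora_weights = lora_utils.LoRA.from_safetensors(
    path="adapter_model.safetensors",  # hypothetical checkpoint path
    scale=lora_alpha / rank,  # or lora_alpha / sqrt(rank), per training setup
    config=config,
    lora_tensor_names=lora_tensor_names,
)

# Serialize the adapter into a standalone "lora_params" TFLite flatbuffer and
# round-trip it back; LoRA.__eq__ compares the tensors with torch.allclose.
lora_fb = lora_weights.to_tflite()
assert lora_utils.LoRA.from_flatbuffers(lora_fb) == lora_weights
assert lora_weights.get_rank() == rank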
ai_edge_torch/generative/layers/model_config.py
@@ -17,8 +17,8 @@

  import dataclasses
  import enum
- from typing import Optional, Sequence, Union
-
+ from typing import Callable, Optional, Sequence, Union
+ from ai_edge_torch.generative.layers import rotary_position_embedding

  @enum.unique
  class ActivationType(enum.Enum):
@@ -218,6 +218,10 @@ class ModelConfig:
    # Softcap on the model output logits.
    final_logit_softcap: Optional[float] = None

+   # The function to call to create the RoPE sin and cos vectors during the
+   # forward pass. Defaults to a standard implementation.
+   build_rope: Callable = rotary_position_embedding.build_rope
+
    @property
    def kv_cache_max(self) -> int:
      if self.kv_cache_max_len > 0:
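
The new build_rope field on ModelConfig makes the RoPE table builder pluggable per model. A minimal sketch, assuming only the build_rope signature shown in the rotary_position_embedding diff below; the base=1_000_000 value is purely illustrative, not a recommendation.

import functools

import torch
from ai_edge_torch.generative.layers import rotary_position_embedding

# Pin a non-default RoPE base frequency while reusing the stock implementation;
# the resulting callable keeps the (input_pos, n_elem, head_dim) interface that
# ModelConfig.build_rope is expected to expose.
custom_build_rope = functools.partial(
    rotary_position_embedding.build_rope, base=1_000_000
)

cos, sin = custom_build_rope(torch.arange(16), n_elem=64, head_dim=64)

# A model definition would then pass build_rope=custom_build_rope when
# constructing its model_config.ModelConfig (or simply rely on the default).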
ai_edge_torch/generative/layers/rotary_position_embedding.py
@@ -32,57 +32,63 @@ def apply_rope(
    """
    x = x.transpose(1, 2)
    head_size = x.size(-1)
-   x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
-   x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
-   rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
-   roped = (x * cos) + (rotated * sin)
+   x1, x2 = torch.split(x, head_size // 2, dim=-1)
+   left = x1 * cos - x2 * sin
+   right = x2 * cos + x1 * sin
+   roped = torch.cat([left, right], dim=-1)
    return roped.transpose(1, 2).type_as(x)


- def apply_rope_inline(
-     q: torch.Tensor,
-     k: torch.Tensor,
+ def build_rope(
      input_pos: torch.Tensor,
      n_elem: int,
+     head_dim: int,
      base: int = 10_000,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
-   """Computes rotary positional embedding inline for a query and key.
+   """Computes rotary positional embedding cosine and sine tensors.

    Args:
-     q: the query tensor.
-     k: the key tensor.
      input_pos: the sequence indices for the query and key
      n_elem: number of elements of the head dimension for RoPE computation
+     base: the base of the exponentiated value for RoPE.

    Returns:
-     output the RoPE'd query and key.
+     cos, sin tensors
    """

    if n_elem <= 0:
-     return q, k
+     return None, None

-   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    freq_exponents = (2.0 / n_elem) * torch.arange(
-       q.shape[-1] // 2, dtype=torch.float32
+       head_dim // 2, dtype=torch.float32
    )
    timescale = float(base) ** freq_exponents
    radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
        0
    ).unsqueeze(0)
-   cos = torch.cos(radians).type_as(q)
-   sin = torch.sin(radians).type_as(q)
+   cos = torch.cos(radians)
+   sin = torch.sin(radians)
+   return cos, sin
+

-   def apply(x, sin, cos):
-     x = x.transpose(1, 2)
-     b, h, s, d = x.shape
-     ans = torch.split(x, d // 2, dim=-1)
-     x1, x2 = ans
-     left = x1 * cos - x2 * sin
-     right = x2 * cos + x1 * sin
-     res = torch.cat([left, right], dim=-1)
-     res = res.transpose(1, 2)
-     return res
+ def apply_rope_inline(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+   """Computes rotary positional embedding inline for a query and key.
+
+   Args:
+     q: the query tensor.
+     k: the key tensor.
+     cos: the cosine tensor.
+     sin: the sine tensor.
+
+   Returns:
+     output the RoPE'd query and key.
+   """

-   q_roped = apply(q, sin, cos)
-   k_roped = apply(k, sin, cos)
+   q_roped = apply_rope(q, cos, sin)
+   k_roped = apply_rope(k, cos, sin)
    return q_roped, k_roped
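
Net effect of this refactor: RoPE is now a two-step flow in which build_rope computes the cos/sin tables from the positions once, and apply_rope / apply_rope_inline consume them. A minimal sketch of the new calling convention; the tensor sizes are arbitrary illustrative values, and tensors follow the (batch, seq_len, num_heads, head_dim) layout that apply_rope transposes internally.

import torch

from ai_edge_torch.generative.layers import rotary_position_embedding as rope

batch, seq_len, num_heads, head_dim = 1, 16, 8, 64

q = torch.randn(batch, seq_len, num_heads, head_dim)
k = torch.randn(batch, seq_len, num_heads, head_dim)
input_pos = torch.arange(seq_len)

# cos/sin depend only on the positions and the head dimension, so they can be
# computed once per forward pass (or even passed in as model inputs, as the
# converter changes in this release suggest) and reused by every layer.
cos, sin = rope.build_rope(input_pos, n_elem=head_dim, head_dim=head_dim)

q_roped, k_roped = rope.apply_rope_inline(q, k, cos, sin)
assert q_roped.shape == q.shape and k_roped.shape == k.shape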