ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +45 -37
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +23 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,557 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""LoRA weights for generative models.
|
17
|
+
|
18
|
+
The current implementation supports attention-only LoRA. Additionally, we expect
|
19
|
+
lora weights for all projections within the attention module (i.e., Q, K, V, O).
|
20
|
+
"""
|
21
|
+
|
22
|
+
import dataclasses
|
23
|
+
from typing import Any, Callable, List, Optional, Tuple
|
24
|
+
|
25
|
+
from ai_edge_torch.generative.layers import model_config
|
26
|
+
import flatbuffers
|
27
|
+
import numpy as np
|
28
|
+
import safetensors
|
29
|
+
import torch
|
30
|
+
import torch.utils._pytree as pytree
|
31
|
+
|
32
|
+
from tensorflow.lite.python import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
|
33
|
+
|
34
|
+
_TFLITE_SCHEMA_VERSION = 3
|
35
|
+
_TFLITE_FILE_IDENTIFIER = b"TFL3"
|
36
|
+
|
37
|
+
|
38
|
+
@dataclasses.dataclass
class LoRAWeight:
  """Low-rank matrix pair for one projection, stored pre-transposed."""

  a_prime: torch.Tensor
  b_prime: torch.Tensor

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Compares two weights by shape and approximate value.

    Args:
      other: Object to compare against; non-LoRAWeight objects are unequal.
      rtol: Relative tolerance passed to `torch.allclose`.
      atol: Absolute tolerance passed to `torch.allclose`.

    Returns:
      True when both matrices have matching shapes and close values.
    """
    if not isinstance(other, LoRAWeight):
      return False
    pairs = (
        (self.a_prime, other.a_prime),
        (self.b_prime, other.b_prime),
    )
    # Both shapes are checked before any values are compared.
    if any(mine.shape != theirs.shape for mine, theirs in pairs):
      return False
    return all(
        torch.allclose(mine, theirs, rtol=rtol, atol=atol)
        for mine, theirs in pairs
    )
|
55
|
+
|
56
|
+
|
57
|
+
@dataclasses.dataclass
class AttentionLoRA:
  """LoRA weights for the four projections of one attention module."""

  query: LoRAWeight
  key: LoRAWeight
  value: LoRAWeight
  output: LoRAWeight

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Returns True when every projection matches within the tolerances.

    Args:
      other: Object to compare against; non-AttentionLoRA objects are unequal.
      rtol: Relative tolerance forwarded to each projection comparison.
      atol: Absolute tolerance forwarded to each projection comparison.
    """
    if not isinstance(other, AttentionLoRA):
      return False
    # Compared in declaration order; stops at the first mismatch.
    for projection in ("query", "key", "value", "output"):
      mine = getattr(self, projection)
      theirs = getattr(other, projection)
      if not mine.__eq__(theirs, rtol=rtol, atol=atol):
        return False
    return True
|
75
|
+
|
76
|
+
|
77
|
+
@dataclasses.dataclass
class LoRAEntry:
  """LoRA weights belonging to a single layer."""

  attention: AttentionLoRA

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Delegates the comparison to the wrapped attention weights.

    Args:
      other: Object to compare against; non-LoRAEntry objects are unequal.
      rtol: Relative tolerance forwarded to the attention comparison.
      atol: Absolute tolerance forwarded to the attention comparison.
    """
    if isinstance(other, LoRAEntry):
      return self.attention.__eq__(other.attention, rtol=rtol, atol=atol)
    return False
|
87
|
+
|
88
|
+
|
89
|
+
@dataclasses.dataclass
class LoRATensorNames:
  """Checkpoint tensor-name templates for the attention LoRA weights.

  `LoRA.from_safetensors` calls `.format(layer_index)` on every field, so each
  value is a format string with one positional placeholder for the layer.
  """

  # Query projection: A and B matrices.
  attn_query_w_a: str
  attn_query_w_b: str

  # Key projection: A and B matrices.
  attn_key_w_a: str
  attn_key_w_b: str

  # Value projection: A and B matrices.
  attn_value_w_a: str
  attn_value_w_b: str

  # Output projection: A and B matrices.
  attn_output_w_a: str
  attn_output_w_b: str
|
104
|
+
|
105
|
+
|
106
|
+
@dataclasses.dataclass
class LoRA:
  """LoRA weights for every layer of a model.

  Attributes:
    adapters: One `LoRAEntry` per layer, ordered by layer index.
  """

  adapters: Tuple[LoRAEntry, ...]

  def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Returns True if `other` is a LoRA whose weights are all close.

    Args:
      other: Object to compare against; non-LoRA objects are unequal.
      rtol: Relative tolerance forwarded to each per-layer comparison.
      atol: Absolute tolerance forwarded to each per-layer comparison.
    """
    if not isinstance(other, LoRA):
      return False
    if len(self.adapters) != len(other.adapters):
      return False
    return all(
        adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
        for adapter, other_adapter in zip(self.adapters, other.adapters)
    )

  def get_rank(self) -> int:
    """Returns the rank of the LoRA weights.

    The rank is read from the second dimension of the first layer's query
    `a_prime`. Raises IndexError when there are no adapters.
    """
    return self.adapters[0].attention.query.a_prime.shape[1]

  @classmethod
  def from_safetensors(
      cls,
      path: str,
      scale: float,
      config: model_config.ModelConfig,
      lora_tensor_names: LoRATensorNames,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a safetensors file.

    Args:
      path: Path to the safetensors file.
      scale: Scale factor for the LoRA weights (applied only to one of the
        projections). The scaling factor depends on the training configuration.
        The common values are either `lora_alpha / rank` or `lora_alpha /
        sqrt(rank)`.
      config: Model configuration; `num_layers` sets how many layers are read.
      lora_tensor_names: Tensor name templates for the LoRA weights. Each one
        is formatted with the layer index.
      dtype: Data type of the LoRA weights. Currently only float32 is supported.

    Returns:
      LoRA weights for all modules.
    """

    def load_weight(f, a_name: str, b_name: str, layer: int) -> LoRAWeight:
      # Both matrices are transposed on load; the scale is folded into A only.
      return LoRAWeight(
          a_prime=f.get_tensor(a_name.format(layer)).to(dtype).T * scale,
          b_prime=f.get_tensor(b_name.format(layer)).to(dtype).T,
      )

    names = lora_tensor_names
    adapters = []
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:
      for i in range(config.num_layers):
        attention_lora = AttentionLoRA(
            query=load_weight(
                f, names.attn_query_w_a, names.attn_query_w_b, i
            ),
            key=load_weight(f, names.attn_key_w_a, names.attn_key_w_b, i),
            value=load_weight(
                f, names.attn_value_w_a, names.attn_value_w_b, i
            ),
            output=load_weight(
                f, names.attn_output_w_a, names.attn_output_w_b, i
            ),
        )
        adapters.append(LoRAEntry(attention=attention_lora))
    # Tuple matches the field annotation and what `_unflatten_lora` produces.
    return cls(adapters=tuple(adapters))

  @classmethod
  def from_flatbuffers(
      cls,
      flatbuffer_model: bytearray,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from FlatBuffers.

    Every tensor of the first subgraph is read as float32 data, reshaped to the
    tensor's recorded shape and converted to `dtype`. Tensor names must start
    with "lora_"; the prefix is removed before the weights are regrouped.

    Args:
      flatbuffer_model: FlatBuffers model.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights for all modules.
    """
    model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
    model = schema_fb.ModelT.InitFromObj(model)

    prefix = "lora_"
    flat_names = []
    tensors = []
    for tensor in model.subgraphs[0].tensors:
      name = tensor.name.decode("utf-8")
      assert name.startswith(prefix), f"Unexpected tensor name: {name}"
      # Strip only the leading prefix; splitting on "lora_" would also cut at
      # any later occurrence of that substring.
      flat_names.append(name[len(prefix):])
      buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
      arr = np.frombuffer(buffer_bytes, dtype=np.float32).reshape(tensor.shape)
      torch_tensor = torch.from_numpy(arr).to(dtype)
      tensors.append(torch_tensor)

    return _unflatten_lora(tensors, (flat_names, []))

  @classmethod
  def zeros(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with zeros.

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights. Currently only float32 is supported.

    Returns:
      LoRA weights with zeros.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def random(
      cls,
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights with random values.

    Values are drawn with `torch.randint` from the range [0, 128).

    Args:
      rank: Rank of the LoRA weights.
      config: Model configuration.
      dtype: Data type of the LoRA weights.

    Returns:
      LoRA weights with random values.
    """
    return cls._from_tensor_generator(
        tensor_generator=lambda shape, dtype: torch.randint(
            low=0, high=128, size=shape, dtype=dtype
        ),
        rank=rank,
        config=config,
        dtype=dtype,
    )

  @classmethod
  def _from_tensor_generator(
      cls,
      tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
      rank: int,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
  ) -> "LoRA":
    """Creates LoRA weights from a tensor generator.

    Args:
      tensor_generator: Called as `tensor_generator(shape, dtype)` for every
        matrix that has to be created.
      rank: Rank of the LoRA weights.
      config: Model configuration supplying the per-layer dimensions.
      dtype: Data type handed to `tensor_generator`.

    Returns:
      LoRA weights with one entry per layer.
    """
    adapters = []

    for i in range(config.num_layers):
      block_config = config.block_config(i)
      q_per_kv = (
          block_config.attn_config.num_heads
          // block_config.attn_config.num_query_groups
      )
      # NOTE(review): q uses heads-per-group * head_dim and k/v use a single
      # head_dim -- TODO confirm these match the projection shapes when
      # num_query_groups > 1.
      q_out_dim = q_per_kv * block_config.attn_config.head_dim
      k_out_dim = v_out_dim = block_config.attn_config.head_dim
      attention_lora = AttentionLoRA(
          query=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, q_out_dim), dtype),
          ),
          key=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, k_out_dim), dtype),
          ),
          value=LoRAWeight(
              a_prime=tensor_generator((config.embedding_dim, rank), dtype),
              b_prime=tensor_generator((rank, v_out_dim), dtype),
          ),
          output=LoRAWeight(
              a_prime=tensor_generator(
                  (
                      block_config.attn_config.num_heads
                      * block_config.attn_config.head_dim,
                      rank,
                  ),
                  dtype,
              ),
              b_prime=tensor_generator((rank, config.embedding_dim), dtype),
          ),
      )
      adapters.append(LoRAEntry(attention=attention_lora))
    # Tuple matches the field annotation and what `_unflatten_lora` produces.
    return cls(adapters=tuple(adapters))

  def to_tflite(self) -> bytearray:
    """Converts LoRA to FlatBuffers."""
    return _lora_to_flatbuffers(self)
|
329
|
+
|
330
|
+
|
331
|
+
def apply_lora(
    x: torch.Tensor,
    lora_weight: LoRAWeight,
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Computes the LoRA contribution `(x @ a_prime) @ b_prime`.

  Args:
    x: Input tensor.
    lora_weight: LoRA weight holding the two matrices to multiply by.
    shape: If given, the product is reshaped to this shape before it is
      returned. If None, the product is returned as computed.

  Returns:
    Output tensor.
  """
  down_projected = torch.matmul(x, lora_weight.a_prime)
  result = torch.matmul(down_projected, lora_weight.b_prime)
  if shape is None:
    return result
  return result.reshape(shape)
|
353
|
+
|
354
|
+
|
355
|
+
def _flatten_attention_lora(
    lora: AttentionLoRA, block_index: int
) -> Tuple[List[torch.Tensor], List[str]]:
  """Lists the eight attention matrices of one layer with their names.

  Args:
    lora: Attention LoRA weights of the layer.
    block_index: Layer index, appended to every generated name.

  Returns:
    The tensors and the matching names, ordered q, k, v, o with `a_prime`
    before `b_prime` for each projection.
  """
  projections = (
      ("q", lora.query),
      ("k", lora.key),
      ("v", lora.value),
      ("o", lora.output),
  )
  tensors = []
  names = []
  for tag, weight in projections:
    tensors.append(weight.a_prime)
    names.append(f"atten_{tag}_a_prime_weight_{block_index}")
    tensors.append(weight.b_prime)
    names.append(f"atten_{tag}_b_prime_weight_{block_index}")
  return tensors, names
|
378
|
+
|
379
|
+
|
380
|
+
def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
  """Flattens all layers of a LoRA into a tensor list plus a context.

  Returns:
    The tensors of every layer in layer order, and a two-element context
    `[names, []]` where `names` lines up with the tensors. The second element
    is always an empty list.
  """
  tensors = []
  names = []
  for layer_idx, entry in enumerate(lora.adapters):
    layer_tensors, layer_names = _flatten_attention_lora(
        lora=entry.attention, block_index=layer_idx
    )
    tensors += layer_tensors
    names += layer_names
  return tensors, [names, []]
|
392
|
+
|
393
|
+
|
394
|
+
def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
  """Flattens a LoRA into `(MappingKey(name), tensor)` pairs.

  Returns:
    The keyed tensors in flattening order, and the list of tensor names.
  """
  tensors, (names, _) = _flatten_lora(lora)
  keyed = []
  for name, tensor in zip(names, tensors):
    keyed.append((pytree.MappingKey(name), tensor))
  return keyed, names
|
400
|
+
|
401
|
+
|
402
|
+
def _unflatten_lora(
    values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
) -> LoRA:
  """Rebuilds a LoRA object from flattened tensors and their names.

  Each name must start with "atten_", contain one of the projection markers
  (for example "q_a_prime") and end with "_<layer index>".

  Args:
    values: Flattened tensors.
    context: Pair whose first element is the list of names matching `values`;
      the second element is ignored.

  Returns:
    A LoRA whose adapters are ordered by ascending layer index.

  Raises:
    ValueError: If a name does not describe a supported attention weight (or
      its trailing layer index is not an integer).
  """
  # (marker searched in the name, AttentionLoRA field, LoRAWeight field).
  # Markers are tested in this order; the first match wins.
  slots = (
      ("q_a_prime", "query", "a_prime"),
      ("q_b_prime", "query", "b_prime"),
      ("k_a_prime", "key", "a_prime"),
      ("k_b_prime", "key", "b_prime"),
      ("v_a_prime", "value", "a_prime"),
      ("v_b_prime", "value", "b_prime"),
      ("o_a_prime", "output", "a_prime"),
      ("o_b_prime", "output", "b_prime"),
  )

  flat_names, _ = context
  adapters = {}
  # Single pass over the pairs; popping from the front of a list is O(n) each.
  for name, weight in zip(flat_names, values):
    block_idx = int(name.rsplit("_", 1)[-1])
    if block_idx not in adapters:
      # Placeholder entry; its matrices are filled in as their names show up.
      adapters[block_idx] = LoRAEntry(
          attention=AttentionLoRA(
              query=LoRAWeight(a_prime=None, b_prime=None),
              key=LoRAWeight(a_prime=None, b_prime=None),
              value=LoRAWeight(a_prime=None, b_prime=None),
              output=LoRAWeight(a_prime=None, b_prime=None),
          )
      )

    if not name.startswith("atten_"):
      raise ValueError(f"Unsupported name: {name}")

    attention = adapters[block_idx].attention
    for marker, projection, field in slots:
      if marker in name:
        setattr(getattr(attention, projection), field, weight)
        break
    else:
      raise ValueError(f"Unsupported name: {name}")

  return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))
|
457
|
+
|
458
|
+
|
459
|
+
# Register LoRA with torch's pytree utilities (runs at import time) so that
# LoRA instances are flattened and rebuilt with the functions defined above.
# NOTE(review): serialized_type_name is an empty string -- confirm this is
# intended.
pytree.register_pytree_node(
    LoRA,
    _flatten_lora,
    _unflatten_lora,
    flatten_with_keys_fn=_flatten_lora_with_keys,
    serialized_type_name="",
)
|
466
|
+
|
467
|
+
|
468
|
+
def _add_buffer(
    builder: flatbuffers.Builder, data: Optional[np.ndarray]
) -> int:
  """Adds a buffer to the FlatBuffers.

  Args:
    builder: FlatBuffers builder to write into.
    data: float32 array to store, or None to add an empty buffer.

  Returns:
    Offset of the finished Buffer table.
  """
  # `Optional[...]` is used in the signature because `X | None` is evaluated
  # when the function is defined and fails on Python versions before 3.10.
  if data is None:
    schema_fb.BufferStartDataVector(builder, 0)
  else:
    assert data.dtype == np.float32
    # The data vector is sized in bytes; the values are then prepended as
    # float32 in reverse order.
    schema_fb.BufferStartDataVector(builder, data.size * data.itemsize)
    for value in reversed(data.flatten().tolist()):
      builder.PrependFloat32(value)
  data_offset = builder.EndVector()

  schema_fb.BufferStart(builder)
  schema_fb.BufferAddData(builder, data_offset)
  buffer_offset = schema_fb.BufferEnd(builder)
  return buffer_offset
|
484
|
+
|
485
|
+
|
486
|
+
def _add_tensor(
    builder: flatbuffers.Builder,
    name: str,
    shape: Tuple[int, ...],
    buffer_idx: int,
) -> int:
  """Adds a tensor to the FlatBuffers.

  The tensor type is always written as FLOAT32.

  Args:
    builder: FlatBuffers builder to write into.
    name: Tensor name.
    shape: Tensor shape; one int32 is written per dimension.
    buffer_idx: Value stored in the tensor's buffer field.

  Returns:
    Offset of the finished Tensor table.
  """
  # The name string and the shape vector are serialized first, before the
  # Tensor table itself is opened.
  name_offset = builder.CreateString(name)
  schema_fb.TensorStartShapeVector(builder, len(shape))
  # Dimensions are prepended in reverse order.
  for dim in reversed(shape):
    builder.PrependInt32(dim)
  shape_offset = builder.EndVector()
  schema_fb.TensorStart(builder)
  schema_fb.TensorAddName(builder, name_offset)
  schema_fb.TensorAddShape(builder, shape_offset)
  schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
  schema_fb.TensorAddBuffer(builder, buffer_idx)
  tensor_offset = schema_fb.TensorEnd(builder)
  return tensor_offset
|
505
|
+
|
506
|
+
|
507
|
+
def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
  """Converts LoRA to FlatBuffers.

  The result is a model with a single subgraph named "lora_params". It holds
  one float32 tensor per flattened LoRA matrix, each named with a "lora_"
  prefix, and one data buffer per tensor after a leading empty buffer.

  Args:
    lora: LoRA weights to serialize.

  Returns:
    The serialized FlatBuffers model.
  """
  tensors, (names, _) = _flatten_lora(lora)
  # Need to manually add the "lora_" prefix to the names here. The export will
  # add the prefix automatically.
  names = [f"lora_{name}" for name in names]
  builder = flatbuffers.Builder(4096)

  # Convention to add an empty buffer in the beginning.
  buffer_offsets = [_add_buffer(builder, None)]
  # Every tensor is converted to float32 before it is written.
  for tensor in tensors:
    buffer_offsets.append(
        _add_buffer(builder, tensor.detach().type(torch.float32).numpy())
    )

  # Vector of all buffer tables; offsets are prepended in reverse order.
  schema_fb.ModelStartBuffersVector(builder, len(buffer_offsets))
  for buffer_offset in reversed(buffer_offsets):
    builder.PrependUOffsetTRelative(buffer_offset)
  buffers_offset = builder.EndVector()

  tensor_offsets = []
  for i, (name, tensor) in enumerate(zip(names, tensors)):
    # Note that the zeroth buffer is empty and reserved for the convention.
    tensor_offsets.append(_add_tensor(builder, name, tensor.shape, i + 1))

  # Vector of all tensor tables; offsets are prepended in reverse order.
  schema_fb.SubGraphStartTensorsVector(builder, len(tensor_offsets))
  for tensor_offset in reversed(tensor_offsets):
    builder.PrependUOffsetTRelative(tensor_offset)
  tensors_offset = builder.EndVector()

  # The single subgraph that owns every tensor.
  string_offset = builder.CreateString("lora_params")
  schema_fb.SubGraphStart(builder)
  schema_fb.SubGraphAddName(builder, string_offset)
  schema_fb.SubGraphAddTensors(builder, tensors_offset)
  subgraph_offset = schema_fb.SubGraphEnd(builder)

  schema_fb.ModelStartSubgraphsVector(builder, 1)
  builder.PrependUOffsetTRelative(subgraph_offset)
  subgraphs_offset = builder.EndVector()

  # Root Model table; the description reuses the "lora_params" text.
  string_offset = builder.CreateString("lora_params")
  schema_fb.ModelStart(builder)
  schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
  schema_fb.ModelAddDescription(builder, string_offset)
  schema_fb.ModelAddBuffers(builder, buffers_offset)
  schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
  model_offset = schema_fb.ModelEnd(builder)
  # Finish with the module-level file identifier constant.
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  flatbuffer_model = builder.Output()

  return flatbuffer_model
|
@@ -17,8 +17,8 @@
|
|
17
17
|
|
18
18
|
import dataclasses
|
19
19
|
import enum
|
20
|
-
from typing import Optional, Sequence, Union
|
21
|
-
|
20
|
+
from typing import Callable, Optional, Sequence, Union
|
21
|
+
from ai_edge_torch.generative.layers import rotary_position_embedding
|
22
22
|
|
23
23
|
@enum.unique
|
24
24
|
class ActivationType(enum.Enum):
|
@@ -218,6 +218,10 @@ class ModelConfig:
|
|
218
218
|
# Softcap on the model output logits.
|
219
219
|
final_logit_softcap: Optional[float] = None
|
220
220
|
|
221
|
+
# The function to call to create the RoPE sin and cos vectors during the
|
222
|
+
# forward pass. Defaults to a standard implementation.
|
223
|
+
build_rope: Callable = rotary_position_embedding.build_rope
|
224
|
+
|
221
225
|
@property
|
222
226
|
def kv_cache_max(self) -> int:
|
223
227
|
if self.kv_cache_max_len > 0:
|
@@ -32,57 +32,63 @@ def apply_rope(
|
|
32
32
|
"""
|
33
33
|
x = x.transpose(1, 2)
|
34
34
|
head_size = x.size(-1)
|
35
|
-
x1 = x
|
36
|
-
|
37
|
-
|
38
|
-
roped = (
|
35
|
+
x1, x2 = torch.split(x, head_size // 2, dim=-1)
|
36
|
+
left = x1 * cos - x2 * sin
|
37
|
+
right = x2 * cos + x1 * sin
|
38
|
+
roped = torch.cat([left, right], dim=-1)
|
39
39
|
return roped.transpose(1, 2).type_as(x)
|
40
40
|
|
41
41
|
|
42
|
-
def
|
43
|
-
q: torch.Tensor,
|
44
|
-
k: torch.Tensor,
|
42
|
+
def build_rope(
|
45
43
|
input_pos: torch.Tensor,
|
46
44
|
n_elem: int,
|
45
|
+
head_dim: int,
|
47
46
|
base: int = 10_000,
|
48
47
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
49
|
-
"""Computes rotary positional embedding
|
48
|
+
"""Computes rotary positional embedding cosine and sine tensors.
|
50
49
|
|
51
50
|
Args:
|
52
|
-
q: the query tensor.
|
53
|
-
k: the key tensor.
|
54
51
|
input_pos: the sequence indices for the query and key
|
55
52
|
n_elem: number of elements of the head dimension for RoPE computation
|
53
|
+
base: the base of the exponentiated value for RoPE.
|
56
54
|
|
57
55
|
Returns:
|
58
|
-
|
56
|
+
cos, sin tensors
|
59
57
|
"""
|
60
58
|
|
61
59
|
if n_elem <= 0:
|
62
|
-
return
|
60
|
+
return None, None
|
63
61
|
|
64
|
-
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
|
65
62
|
freq_exponents = (2.0 / n_elem) * torch.arange(
|
66
|
-
|
63
|
+
head_dim // 2, dtype=torch.float32
|
67
64
|
)
|
68
65
|
timescale = float(base) ** freq_exponents
|
69
66
|
radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
|
70
67
|
0
|
71
68
|
).unsqueeze(0)
|
72
|
-
cos = torch.cos(radians)
|
73
|
-
sin = torch.sin(radians)
|
69
|
+
cos = torch.cos(radians)
|
70
|
+
sin = torch.sin(radians)
|
71
|
+
return cos, sin
|
72
|
+
|
74
73
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
74
|
+
def apply_rope_inline(
|
75
|
+
q: torch.Tensor,
|
76
|
+
k: torch.Tensor,
|
77
|
+
cos: torch.Tensor,
|
78
|
+
sin: torch.Tensor,
|
79
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
80
|
+
"""Computes rotary positional embedding inline for a query and key.
|
81
|
+
|
82
|
+
Args:
|
83
|
+
q: the query tensor.
|
84
|
+
k: the key tensor.
|
85
|
+
cos: the cosine tensor.
|
86
|
+
sin: the sine tensor.
|
87
|
+
|
88
|
+
Returns:
|
89
|
+
output the RoPE'd query and key.
|
90
|
+
"""
|
85
91
|
|
86
|
-
q_roped =
|
87
|
-
k_roped =
|
92
|
+
q_roped = apply_rope(q, cos, sin)
|
93
|
+
k_roped = apply_rope(k, cos, sin)
|
88
94
|
return q_roped, k_roped
|