ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl
- ai_edge_torch/_config.py +26 -9
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +70 -12
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/normalization.py +2 -50
- ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +21 -16
- ai_edge_torch/generative/utilities/verifier.py +4 -4
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/rotary_position_embedding.py
CHANGED
@@ -32,64 +32,57 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1, x2 = torch.split(x, head_size // 2, dim=-1)
-  left = x1 * cos - x2 * sin
-  right = x2 * cos + x1 * sin
-  roped = torch.cat([left, right], dim=-1)
+  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+  roped = (x * cos) + (rotated * sin)
   return roped.transpose(1, 2).type_as(x)
 
 
-def build_rope(
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
     input_pos: torch.Tensor,
     n_elem: int,
-    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding
+  """Computes rotary positional embedding inline for a query and key.
 
   Args:
+    q: the query tensor.
+    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
-    base: the base of the exponentiated value for RoPE.
 
   Returns:
-
+    output the RoPE'd query and key.
   """
 
   if n_elem <= 0:
-    return
+    return q, k
 
   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      head_dim // 2, dtype=torch.float32
+      q.shape[-1] // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
-  return cos, sin
-
+  cos = torch.cos(radians).type_as(q)
+  sin = torch.sin(radians).type_as(q)
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
-
-  Args:
-    q: the query tensor.
-    k: the key tensor.
-    cos: the cosine tensor.
-    sin: the sine tensor.
-
-  Returns:
-    output the RoPE'd query and key.
-  """
+  def apply(x, sin, cos):
+    x = x.transpose(1, 2)
+    b, h, s, d = x.shape
+    ans = torch.split(x, d // 2, dim=-1)
+    x1, x2 = ans
+    left = x1 * cos - x2 * sin
+    right = x2 * cos + x1 * sin
+    res = torch.cat([left, right], dim=-1)
+    res = res.transpose(1, 2)
+    return res
 
-  q_roped =
-  k_roped =
+  q_roped = apply(q, sin, cos)
+  k_roped = apply(k, sin, cos)
   return q_roped, k_roped
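For context on the rewrite: the new apply_rope computes x * cos + rotate_half(x) * sin in place of the old split/left/right/concat sequence. A standalone equivalence sketch follows; it is hedged, not the library code, and assumes cos/sin carry duplicated halves along the last dimension, consistent with the (B, nh, T, hs) shape comments in the diff.

import torch

def rope_rotate_half(x, cos, sin):
  x = x.transpose(1, 2)                       # (B, nh, T, hs)
  hs = x.size(-1)
  x1, x2 = x[..., : hs // 2], x[..., hs // 2 :]
  rotated = torch.cat((-x2, x1), dim=-1)      # rotate-half
  return ((x * cos) + (rotated * sin)).transpose(1, 2)

def rope_split_concat(x, cos, sin):
  x = x.transpose(1, 2)
  hs = x.size(-1)
  x1, x2 = torch.split(x, hs // 2, dim=-1)
  left = x1 * cos[..., : hs // 2] - x2 * sin[..., : hs // 2]
  right = x2 * cos[..., hs // 2 :] + x1 * sin[..., hs // 2 :]
  return torch.cat([left, right], dim=-1).transpose(1, 2)

B, T, nh, hs = 1, 4, 2, 8
x = torch.randn(B, T, nh, hs)
angle = torch.randn(T, hs // 2)
cos = torch.cos(angle).repeat(1, 2)           # (T, hs), duplicated halves
sin = torch.sin(angle).repeat(1, 2)
assert torch.allclose(rope_rotate_half(x, cos, sin),
                      rope_split_concat(x, cos, sin), atol=1e-6)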
ai_edge_torch/generative/test/test_lora.py
ADDED
@@ -0,0 +1,147 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A suite of tests to validate LoRA utilities."""
+
+from ai_edge_torch.generative.layers import lora as lora_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import torch
+from absl.testing import absltest as googletest
+from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import
+
+
+class TestLora(googletest.TestCase):
+  """Tests for LoRA utilities."""
+
+  def test_safetensors_builder(self):
+    """Converts a safetensors file to a LoRA module."""
+
+    tensor_names = lora_utils.LoRATensorNames(
+        attn_query_w_a=(
+            "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
+        ),
+        attn_query_w_b=(
+            "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
+        ),
+        attn_key_w_a=(
+            "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
+        ),
+        attn_key_w_b=(
+            "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
+        ),
+        attn_value_w_a=(
+            "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
+        ),
+        attn_value_w_b=(
+            "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
+        ),
+        attn_output_w_a=(
+            "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
+        ),
+        attn_output_w_b=(
+            "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
+        ),
+    )
+
+    safetensors_file = resource_loader.get_path_to_datafile(
+        "fixtures/test_lora_rank16.safetensors"
+    )
+    config = self._get_test_config(
+        num_layers=1,
+        head_dim=8,
+        num_query_groups=1,
+        kv_cache_max_len=16,
+    )
+    lora = lora_utils.LoRA.from_safetensors(
+        safetensors_file,
+        scale=1.0,
+        lora_tensor_names=tensor_names,
+        config=config,
+    )
+    self.assertEqual(lora.get_rank(), 16)
+
+  def test_torch_export(self):
+    """Tests the export of the LoRA module."""
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
+        x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
+        return x
+
+    n = 1
+    head_dim = 2
+    num_query_groups = 1
+    key_length = 4
+    config = self._get_test_config(
+        num_layers=n,
+        head_dim=head_dim,
+        num_query_groups=num_query_groups,
+        kv_cache_max_len=key_length,
+    )
+    inputs = torch.zeros((n, 1, head_dim))
+    lora = lora_utils.LoRA.zeros(rank=16, config=config)
+    model = TestModel()
+    exported_program = torch.export.export(model, (inputs, lora))
+    input_specs = exported_program.graph_signature.input_specs
+    # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
+    # 2 for output lora.
+    self.assertLen(input_specs, 9)
+    self.assertEqual(input_specs[0].arg.name, "x")
+    self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
+    self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
+    self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
+    self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
+    self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
+    self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
+    self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
+    self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")
+
+  def test_lora_tflite_serialization(self):
+    """Tests the serialization of the LoRA module."""
+    config = self._get_test_config(
+        num_layers=2,
+        head_dim=8,
+        num_query_groups=1,
+        kv_cache_max_len=16,
+    )
+    lora = lora_utils.LoRA.random(rank=16, config=config)
+    flatbuffer_model = lora.to_tflite()
+    recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
+    self.assertEqual(lora, recovered_lora)
+
+  def _get_test_config(
+      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+  ):
+    """Returns a test model config."""
+    attn_config = cfg.AttentionConfig(
+        num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
+    )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config, ff_config=None
+    )
+    config = cfg.ModelConfig(
+        kv_cache_max_len=kv_cache_max_len,
+        embedding_dim=head_dim,
+        block_configs=block_config,
+        num_layers=num_layers,
+        max_seq_len=None,
+        vocab_size=None,
+    )
+    return config
+
+
+if __name__ == "__main__":
+  googletest.main()
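A side note on the templated names the test wires up: each LoRATensorNames entry keeps a "{}" slot for the layer index, so one template covers every transformer block. Illustration only; the actual formatting happens inside lora_utils:

template = "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
print(template.format(0))   # ...layers.0.self_attn.q_proj.lora_A.weight
print(template.format(11))  # ...layers.11.self_attn.q_proj.lora_A.weight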
ai_edge_torch/generative/utilities/converter.py
CHANGED
@@ -15,16 +15,15 @@
 
 """Common utility functions for model conversion."""
 
-
-from typing import Union
-
+import os
+from typing import Optional, Union
 from ai_edge_torch._convert import converter as converter_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
-import torch.nn as nn
 
 
 class ExportableModule(torch.nn.Module):
@@ -41,11 +40,13 @@ class ExportableModule(torch.nn.Module):
 
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
-    tflite_path: str,
+    output_path: str,
+    output_name_prefix: str,
     prefill_seq_len: Union[int, list[int]],
     pixel_values_size: torch.Size = None,
     quantize: bool = True,
     config: cfg.ModelConfig = None,
+    lora_ranks: Optional[list[int]] = None,
    export_config: ExportConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.
@@ -79,21 +80,65 @@ def convert_to_tflite(
 
   Args:
     pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-
-
-
+    output_path (str): The path to export the tflite model.
+    output_name_prefix (str): The prefix of the tflite model name.
+    prefill_seq_len (Union[int, list[int]]): The prefill sequence length to
+      use. If a list, the model will have multiple prefill signatures.
     pixel_values_size (torch.Size, optional): The size of pixel values to pass
       to the model. If None, the model is not expected to take pixel values.
     quantize (bool, optional): Whether the model should be quantized. Defaults
       to True.
     config (cfg.ModelConfig, optional): The model config used to configure KV
      cache. If None, it uses the config of the pytorch_model.
+    lora_ranks (list[int], optional): The ranks of the LoRA layers. If None,
+      no LoRA signatures will be added.
   """
+  # pylint: disable=protected-access
+  torch._dynamo.config.cache_size_limit = 64
+
+  config = config if config else pytorch_model.config
   prefill_seq_lens = (
       [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
   )
+  loras = [None]
+  if lora_ranks is not None:
+    for rank in lora_ranks:
+      lora = lora_utils.LoRA.zeros(rank, config)
+      loras.append(lora)
+
+  quant_suffix = 'q8' if quantize else 'f32'
+  kv_size = config.kv_cache_max_len
+  lora_suffix = (
+      '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
+  )
+  output_filename = (
+      f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
+  )
+  output_file = os.path.join(output_path, output_filename)
+
+  _export_helper(
+      pytorch_model,
+      output_file,
+      prefill_seq_lens,
+      pixel_values_size,
+      quantize,
+      config,
+      loras,
+      export_config,
+  )
 
-
+
+def _export_helper(
+    pytorch_model: torch.nn.Module,
+    output_file: str,
+    prefill_seq_lens: list[int],
+    pixel_values_size: torch.Size,
+    quantize: bool,
+    config: cfg.ModelConfig,
+    loras: list[None | lora_utils.LoRA],
+    export_config: ExportConfig,
+):
+  """Helper function to export a model to tflite."""
   prefill_tokens_list = []
   prefill_input_pos_list = []
   for seq_len in prefill_seq_lens:
@@ -108,9 +153,7 @@ def convert_to_tflite(
 
   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(
-      config if config else pytorch_model.config
-  )
+  kv = kv_utils.KVCache.from_model_config(config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
 
@@ -119,44 +162,54 @@ def convert_to_tflite(
   mod = ExportableModule(pytorch_model, export_config=export_config)
 
   converter = converter_utils.Converter()
-  for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  for lora in loras:
+    for i in range(len(prefill_seq_lens)):
+      prefill_seq_len = prefill_seq_lens[i]
+      prefill_tokens = prefill_tokens_list[i]
+      prefill_input_pos = prefill_input_pos_list[i]
+      if i == 0 and len(prefill_seq_lens) == 1:
+        prefill_signature_name = 'prefill'
+      else:
+        prefill_signature_name = f'prefill_{prefill_seq_len}'
+
+      sample_kwargs = {
+          'tokens': prefill_tokens,
+          'input_pos': prefill_input_pos,
+          'kv_cache': kv,
+      }
+      if lora is not None:
+        prefill_signature_name += f'_lora_r{lora.get_rank()}'
+        sample_kwargs['lora'] = lora
+
       converter.add_signature(
-          prefill_signature_name
+          prefill_signature_name,
           mod,
-          sample_kwargs=
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-              'pixel_values': prefill_pixel_values,
-          },
+          sample_kwargs=sample_kwargs,
       )
 
-
-
-
-
-
-
-
-
-
+      if prefill_pixel_values is not None:
+        converter.add_signature(
+            prefill_signature_name + '_pixel',
+            mod,
+            sample_kwargs={
+                **sample_kwargs,
+                'pixel_values': prefill_pixel_values,
+            },
+        )
+
+    sample_kwargs = {
+        'tokens': decode_token,
+        'input_pos': decode_input_pos,
+        'kv_cache': kv,
+    }
+    if lora is not None:
+      sample_kwargs['lora'] = lora
+
+    converter.add_signature(
+        'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
+        mod,
+        sample_kwargs=sample_kwargs,
+    )
 
   edge_model = converter.convert(quant_config=quant_config)
-  edge_model.export(
+  edge_model.export(output_file)
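A hedged usage sketch of the reworked entry point; the tiny_llama builder and the paths are illustrative assumptions, not part of this diff:

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.utilities import converter

# Any of the generative example models should work here; the exact builder
# signature is an assumption.
pytorch_model = tiny_llama.build_model("/path/to/checkpoint")
converter.convert_to_tflite(
    pytorch_model,
    output_path="/tmp",
    output_name_prefix="tiny_llama",
    prefill_seq_len=[256, 1024],
    quantize=True,
    lora_ranks=[16],
)

Per the naming logic above, this writes /tmp/tiny_llama_q8_ekv<kv_cache_max_len>_lora16.tflite with signatures prefill_256, prefill_1024, decode, plus the LoRA variants prefill_256_lora_r16, prefill_1024_lora_r16, and decode_lora_r16.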
ai_edge_torch/generative/utilities/model_builder.py
CHANGED
@@ -22,13 +22,14 @@ from typing import Optional, Tuple
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
 
+
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
@@ -86,6 +87,13 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -97,6 +105,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      lora: Optional[lora_utils.LoRA] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
@@ -107,28 +116,23 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-
+    cos, sin = self.rope_cache
+    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    return self.
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    return self.forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
     )
 
-  def
+  def forward_with_embeds(
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
       mask: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      lora: Optional[lora_utils.LoRA] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     """Forwards the model with input embeddings."""
@@ -141,13 +145,14 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
+      lora_adapter = lora.adapters[i] if lora else None
+      x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
       if kv_entry:
-
-    updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
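The net effect of these model_builder changes: the cos/sin tables are built once in __init__ and gathered per step with index_select, rather than recomputed from input_pos on every forward. A standalone sketch of the pattern (the cache layout here is an assumption, not the body of attn_utils.build_rope_cache):

import torch

def build_rope_cache(size: int, dim: int, base: int = 10_000):
  # One row of cos/sin per position, precomputed once at model init.
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  radians = torch.outer(torch.arange(size).float(), inv_freq)  # (size, dim/2)
  emb = torch.cat((radians, radians), dim=-1)                  # (size, dim)
  return torch.cos(emb), torch.sin(emb)

cos, sin = build_rope_cache(size=1024, dim=64)
input_pos = torch.tensor([5, 6, 7])  # positions being decoded
rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))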
ai_edge_torch/generative/utilities/verifier.py
CHANGED
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""
 
 import logging
-from typing import Any, List
+from typing import Any, List, Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -134,7 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
       max_new_tokens: int,
       pixel_values: torch.Tensor = None,
-      eos_token_id: int =
+      eos_token_id: Optional[int] = None,
   ) -> torch.IntTensor:
     input_ids = prompts[0].int().tolist()
     tokens = torch.tensor([input_ids])
@@ -146,7 +146,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       )
       generated_token = logits[0][-1].argmax().item()
       input_ids.append(generated_token)
-      if generated_token == eos_token_id:
+      if eos_token_id is not None and generated_token == eos_token_id:
        break
       tokens = torch.tensor([[generated_token]])
       input_pos = torch.tensor([len(input_ids) - 1])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
   outputs_reauthored = reauthored_model.generate(
       prompt_tokens,
       max_new_tokens,
-      eos_token_id=tokenizer.tokenizer.eos_token_id,
+      eos_token_id=getattr(tokenizer.tokenizer, "eos_token_id", None),
   )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
ai_edge_torch/odml_torch/_torch_future.py
CHANGED
@@ -73,3 +73,16 @@ def safe_run_decompositions(exported_program, decomp_table=None):
       node.target = lambda self, size: torch.reshape(self.contiguous(), size)
 
   return exported_program.run_decompositions(decomp_table)
+
+
+def dummy_decomp_table():
+  """Build dummy decomp table for run_decompositions without any decompositions.
+
+  Compatible for torch<=2.5.
+
+  Returns:
+    Decomp table for ExportedProgram.run_decompositions.
+  """
+  return {
+      torch._ops.OperatorBase(): lambda: None,
+  }
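Why a dummy table rather than an empty one: on torch<=2.5 (per the docstring's compatibility note) a falsy table falls back to the default decompositions, whereas a table keyed by a fresh OperatorBase() matches no real aten op, so run_decompositions only retraces and canonicalizes the graph. A sketch under those assumptions:

import torch

class AddOne(torch.nn.Module):
  def forward(self, x):
    return x + 1

ep = torch.export.export(AddOne(), (torch.randn(2, 2),))
# The fresh OperatorBase() key matches no node target, so nothing is
# decomposed; the program is merely retraced.
ep = ep.run_decompositions({torch._ops.OperatorBase(): lambda: None})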
ai_edge_torch/odml_torch/export.py
CHANGED
@@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
   def in_i32(x: int):
     return -2147483648 <= x <= 2147483647
 
+  def to_int32(x: torch.Tensor):
+    return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
+
   def rewrite_arange(node: torch.fx.Node):
     tensor_meta = node.meta.get("tensor_meta", None)
     if not tensor_meta:
@@ -249,7 +252,7 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
     if not (in_i32(start) and in_i32(end)):
       return
     op = node.target
-    node.target = lambda *args, **kwargs: op(*args, **kwargs)
+    node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
 
   graph_module = exported_program.graph_module
   for node in graph_module.graph.nodes:
@@ -305,8 +308,9 @@ def exported_program_to_mlir(
 
   _convert_i64_to_i32(exported_program)
 
+  # No decompositions but just retracing/canonicalization.
   exported_program = _torch_future.safe_run_decompositions(
-      exported_program,
+      exported_program, _torch_future.dummy_decomp_table()
   )
 
   # Passes below mutate the exported program to a state not executable by torch.
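Background for the i64-to-i32 change: torch.arange over integer bounds produces int64 by default, and the new to_int32 helper wraps each matched node's target so the cast happens inside the traced graph whenever the bounds fit in 32 bits. A quick illustration of the underlying op:

import torch

x = torch.arange(0, 10)  # int64 by default
y = torch.ops.aten._to_copy.default(x, dtype=torch.int32)
assert x.dtype == torch.int64 and y.dtype == torch.int32

# The pass rewrites a matched node's target roughly like this:
op = torch.ops.aten.arange.start
wrapped = lambda *args, **kwargs: torch.ops.aten._to_copy.default(
    op(*args, **kwargs), dtype=torch.int32
)
assert wrapped(0, 10).dtype == torch.int32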
ai_edge_torch/odml_torch/lowerings/decomp.py
CHANGED
@@ -55,6 +55,10 @@ def decompositions():
       ],
   )
 
+  # Override noop aten op decompositions for faster run_decompositions.
+  decompositions[torch.ops.aten.alias.default] = lambda x: x
+  decompositions[torch.ops.aten.detach.default] = lambda x: x
+
   # Override _safe_softmax decompositions with regular softmax.
   # _safe_softmax introduces additional check-select ops to guard extreme
   # input values to softmax, which could make the converted model inefficient
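For context, a decomposition table maps an aten op to a Python callable that re-expresses it, so registering identity lambdas for no-op ops such as alias and detach lets run_decompositions collapse those nodes cheaply. A standalone sketch, not the library's actual table:

import torch

class M(torch.nn.Module):
  def forward(self, x):
    return torch.ops.aten.alias.default(x) + 1

table = {
    torch.ops.aten.alias.default: lambda x: x,
    torch.ops.aten.detach.default: lambda x: x,
}
ep = torch.export.export(M(), (torch.randn(2),))
ep = ep.run_decompositions(table)  # alias nodes collapse to their input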
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20250108
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|