ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries, as they appear in the public registry. It is provided for informational purposes only.
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +45 -37
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +23 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_lora.py
@@ -0,0 +1,147 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A suite of tests to validate LoRA utilities."""
+
+from ai_edge_torch.generative.layers import lora as lora_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import torch
+from absl.testing import absltest as googletest
+from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import
+
+
+class TestLora(googletest.TestCase):
+  """Tests for LoRA utilities."""
+
+  def test_safetensors_builder(self):
+    """Converts a safetensors file to a LoRA module."""
+
+    tensor_names = lora_utils.LoRATensorNames(
+        attn_query_w_a=(
+            "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight"
+        ),
+        attn_query_w_b=(
+            "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight"
+        ),
+        attn_key_w_a=(
+            "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight"
+        ),
+        attn_key_w_b=(
+            "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight"
+        ),
+        attn_value_w_a=(
+            "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight"
+        ),
+        attn_value_w_b=(
+            "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight"
+        ),
+        attn_output_w_a=(
+            "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight"
+        ),
+        attn_output_w_b=(
+            "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight"
+        ),
+    )
+
+    safetensors_file = resource_loader.get_path_to_datafile(
+        "fixtures/test_lora_rank16.safetensors"
+    )
+    config = self._get_test_config(
+        num_layers=1,
+        head_dim=8,
+        num_query_groups=1,
+        kv_cache_max_len=16,
+    )
+    lora = lora_utils.LoRA.from_safetensors(
+        safetensors_file,
+        scale=1.0,
+        lora_tensor_names=tensor_names,
+        config=config,
+    )
+    self.assertEqual(lora.get_rank(), 16)
+
+  def test_torch_export(self):
+    """Tests the export of the LoRA module."""
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor:
+        x += lora_utils.apply_lora(x, lora.adapters[0].attention.query)
+        return x
+
+    n = 1
+    head_dim = 2
+    num_query_groups = 1
+    key_length = 4
+    config = self._get_test_config(
+        num_layers=n,
+        head_dim=head_dim,
+        num_query_groups=num_query_groups,
+        kv_cache_max_len=key_length,
+    )
+    inputs = torch.zeros((n, 1, head_dim))
+    lora = lora_utils.LoRA.zeros(rank=16, config=config)
+    model = TestModel()
+    exported_program = torch.export.export(model, (inputs, lora))
+    input_specs = exported_program.graph_signature.input_specs
+    # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora,
+    # 2 for output lora.
+    self.assertLen(input_specs, 9)
+    self.assertEqual(input_specs[0].arg.name, "x")
+    self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0")
+    self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0")
+    self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0")
+    self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0")
+    self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0")
+    self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0")
+    self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0")
+    self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0")
+
+  def test_lora_tflite_serialization(self):
+    """Tests the serialization of the LoRA module."""
+    config = self._get_test_config(
+        num_layers=2,
+        head_dim=8,
+        num_query_groups=1,
+        kv_cache_max_len=16,
+    )
+    lora = lora_utils.LoRA.random(rank=16, config=config)
+    flatbuffer_model = lora.to_tflite()
+    recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
+    self.assertEqual(lora, recovered_lora)
+
+  def _get_test_config(
+      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+  ):
+    """Returns a test model config."""
+    attn_config = cfg.AttentionConfig(
+        num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
+    )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config, ff_config=None
+    )
+    config = cfg.ModelConfig(
+        kv_cache_max_len=kv_cache_max_len,
+        embedding_dim=head_dim,
+        block_configs=block_config,
+        num_layers=num_layers,
+        max_seq_len=None,
+        vocab_size=None,
+    )
+    return config
+
+
+if __name__ == "__main__":
+  googletest.main()
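The new test exercises the public LoRA surface added in this release (LoRATensorNames, LoRA.from_safetensors, LoRA.zeros/random, to_tflite/from_flatbuffers, apply_lora). The implementation itself lives in the new lora.py (+557 lines), which this view does not show. For orientation only, a minimal sketch of the standard low-rank update that such an adapter represents; the function and argument names below are illustrative, not the package's API:

    import torch

    def lora_delta(x, w_a, w_b, scale=1.0):
      # x: (..., in_features); w_a: (rank, in_features); w_b: (out_features, rank).
      # The adapter contributes scale * (x @ A^T) @ B^T on top of the frozen base
      # projection, so only rank * (in_features + out_features) extra weights are carried.
      return scale * (x @ w_a.T) @ w_b.T

    x = torch.randn(1, 8, 64)
    w_a = torch.zeros(16, 64)  # rank 16, like test_lora_rank16.safetensors
    w_b = torch.zeros(64, 16)
    print(lora_delta(x, w_a, w_b).abs().sum())  # all-zero adapter -> no change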
ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
+
+  def test_smollm2(self):
+    config = smollm.get_fake_model_config_v2()
+    pytorch_model = smollm.SmolLM2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
ai_edge_torch/generative/utilities/converter.py
@@ -15,16 +15,15 @@
 
 """Common utility functions for model conversion."""
 
-
-from typing import
-
+import os
+from typing import Optional, Union
 from ai_edge_torch._convert import converter as converter_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
-import torch.nn as nn
 
 
 class ExportableModule(torch.nn.Module):
@@ -41,11 +40,13 @@ class ExportableModule(torch.nn.Module):
 
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
-
+    output_path: str,
+    output_name_prefix: str,
     prefill_seq_len: Union[int, list[int]],
     pixel_values_size: torch.Size = None,
     quantize: bool = True,
     config: cfg.ModelConfig = None,
+    lora_ranks: Optional[list[int]] = None,
     export_config: ExportConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.
@@ -79,21 +80,65 @@ def convert_to_tflite(
 
   Args:
     pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-
-
-
+    output_path (str): The path to export the tflite model.
+    output_name_prefix (str): The prefix of the tflite model name.
+    prefill_seq_len (Union[int, list[int]]): The prefill sequence length to
+      use. If a list, the model will have multiple prefill signatures.
     pixel_values_size (torch.Size, optional): The size of pixel values to pass
      to the model. If None, the model is not expected to take pixel values.
     quantize (bool, optional): Whether the model should be quanized. Defaults
      to True.
     config (cfg.ModelConfig, optional): The model config used to configure KV
      cache. If None, it uses the config of the pytorch_model.
+    lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
+      no LoRA signatures will be added.
   """
+  # pylint: disable=protected-access
+  torch._dynamo.config.cache_size_limit = 64
+
+  config = config if config else pytorch_model.config
   prefill_seq_lens = (
       [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
   )
+  loras = [None]
+  if lora_ranks is not None:
+    for rank in lora_ranks:
+      lora = lora_utils.LoRA.zeros(rank, config)
+      loras.append(lora)
+
+  quant_suffix = 'q8' if quantize else 'f32'
+  kv_size = config.kv_cache_max_len
+  lora_suffix = (
+      '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
+  )
+  output_filename = (
+      f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
+  )
+  output_file = os.path.join(output_path, output_filename)
+
+  _export_helper(
+      pytorch_model,
+      output_file,
+      prefill_seq_lens,
+      pixel_values_size,
+      quantize,
+      config,
+      loras,
+      export_config,
+  )
 
-
+
+def _export_helper(
+    pytorch_model: torch.nn.Module,
+    output_file: str,
+    prefill_seq_lens: list[int],
+    pixel_values_size: torch.Size,
+    quantize: bool,
+    config: cfg.ModelConfig,
+    loras: list[None | lora_utils.LoRA],
+    export_config: ExportConfig,
+):
+  """Helper function to export a model to tflite."""
   prefill_tokens_list = []
   prefill_input_pos_list = []
   for seq_len in prefill_seq_lens:
@@ -108,9 +153,7 @@ def convert_to_tflite(
 
   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(
-      config if config else pytorch_model.config
-  )
+  kv = kv_utils.KVCache.from_model_config(config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
 
@@ -119,44 +162,54 @@ def convert_to_tflite(
   mod = ExportableModule(pytorch_model, export_config=export_config)
 
   converter = converter_utils.Converter()
-  for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  for lora in loras:
+    for i in range(len(prefill_seq_lens)):
+      prefill_seq_len = prefill_seq_lens[i]
+      prefill_tokens = prefill_tokens_list[i]
+      prefill_input_pos = prefill_input_pos_list[i]
+      if i == 0 and len(prefill_seq_lens) == 1:
+        prefill_signature_name = 'prefill'
+      else:
+        prefill_signature_name = f'prefill_{prefill_seq_len}'
+
+      sample_kwargs = {
+          'tokens': prefill_tokens,
+          'input_pos': prefill_input_pos,
+          'kv_cache': kv,
+      }
+      if lora is not None:
+        prefill_signature_name += f'_lora_r{lora.get_rank()}'
+        sample_kwargs['lora'] = lora
+
       converter.add_signature(
-          prefill_signature_name
+          prefill_signature_name,
           mod,
-          sample_kwargs=
-          'tokens': prefill_tokens,
-          'input_pos': prefill_input_pos,
-          'kv_cache': kv,
-          'pixel_values': prefill_pixel_values,
-          },
+          sample_kwargs=sample_kwargs,
       )
 
-
-
-
-
-
-
-
-
-
+      if prefill_pixel_values is not None:
+        converter.add_signature(
+            prefill_signature_name + '_pixel',
+            mod,
+            sample_kwargs={
+                **sample_kwargs,
+                'pixel_values': prefill_pixel_values,
+            },
+        )
+
+    sample_kwargs = {
+        'tokens': decode_token,
+        'input_pos': decode_input_pos,
+        'kv_cache': kv,
+    }
+    if lora is not None:
+      sample_kwargs['lora'] = lora
+
+    converter.add_signature(
+        'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
+        mod,
+        sample_kwargs=sample_kwargs,
+    )
 
   edge_model = converter.convert(quant_config=quant_config)
-  edge_model.export(
+  edge_model.export(output_file)
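Taken together, these converter.py hunks split conversion into the public convert_to_tflite wrapper and an _export_helper that registers one prefill/decode signature pair per LoRA entry, plus a '_pixel' variant when pixel values are used. A standalone worked example of the naming scheme the new code produces; the concrete prefix, KV size, ranks, and sequence lengths below are assumed values, not taken from the package:

    quantize = True
    lora_ranks = [16]
    prefill_seq_lens = [8, 64]
    output_name_prefix, kv_size = 'smollm', 1280

    # Output file name, mirroring the logic in convert_to_tflite above.
    quant_suffix = 'q8' if quantize else 'f32'
    lora_suffix = '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
    print(f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite')
    # -> smollm_q8_ekv1280_lora16.tflite

    # Signature names, mirroring the loop in _export_helper above.
    for rank in [None] + lora_ranks:
      for seq_len in prefill_seq_lens:
        name = 'prefill' if len(prefill_seq_lens) == 1 else f'prefill_{seq_len}'
        print(name if rank is None else f'{name}_lora_r{rank}')
      print('decode' if rank is None else f'decode_lora_r{rank}')
    # -> prefill_8, prefill_64, decode,
    #    prefill_8_lora_r16, prefill_64_lora_r16, decode_lora_r16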
ai_edge_torch/generative/utilities/model_builder.py
@@ -22,12 +22,15 @@ from typing import Optional, Tuple
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
 
+
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
@@ -85,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -103,6 +99,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      lora: Optional[lora_utils.LoRA] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
@@ -113,13 +110,23 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = self.config.build_rope(
+        input_pos=input_pos,
+        n_elem=n_elem,
+        base=attn_config.rotary_base,
+        head_dim=attn_config.head_dim,
+        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+    )
+
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self.forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
+        input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
     )
 
   def forward_with_embeds(
@@ -129,6 +136,7 @@ class DecoderOnlyModel(nn.Module):
       mask: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      lora: Optional[lora_utils.LoRA] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     """Forwards the model with input embeddings."""
@@ -141,13 +149,14 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
-
+      lora_adapter = lora.adapters[i] if lora else None
+      x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
       if kv_entry:
-
-    updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
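The model_builder.py hunks make two changes to DecoderOnlyModel: the rotary embedding is no longer precomputed as a rope_cache at construction time but built per call through config.build_rope, and an optional per-layer LoRA adapter is threaded into every transformer block. A rough sketch of the kind of callable build_rope is expected to be, inferred only from the call site above; the real default comes from the reworked rotary_position_embedding.py, which this view does not show:

    import torch

    def build_rope_sketch(input_pos, n_elem, base, head_dim):
      # Returns (cos, sin) tables for the requested positions, computed on the
      # fly instead of being indexed out of a table cached at model build time.
      freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
      angles = input_pos.unsqueeze(-1).float() * freqs
      return torch.cos(angles), torch.sin(angles)

    cos, sin = build_rope_sketch(torch.arange(4), n_elem=64, base=10000, head_dim=64)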
ai_edge_torch/hlfb/mark_pattern/__init__.py
@@ -17,7 +17,7 @@ from typing import Any
 import uuid
 
 from ai_edge_torch import lowertools
-from ai_edge_torch.hlfb.mark_pattern import
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
 
@@ -87,7 +87,7 @@ def mark_pattern(
     m.meta["ORIGINAL_NODE"] = n
 
   # Sanitize graph_module to match in the same way as pattern's graph_module.
-  graph_module_to_match =
+  graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
   match_with_attrs = pattern.match(graph_module_to_match)
 
@@ -111,13 +111,25 @@ def mark_pattern(
           is_input=True,
       )
 
-      # Only replace input by the marker node for those nodes used in the
+      # Only replace input by the marker node for those nodes used in the
+      # pattern.
       in_pattern_nodes = set(match.nodes_map.values())
       for user in input_node.users.keys():
-        if user in in_pattern_nodes:
-
-
-
+        if user not in in_pattern_nodes:
+          continue
+
+        user.meta["ORIGINAL_NODE"].replace_input_with(
+            input_node.meta["ORIGINAL_NODE"], new_input_node
+        )
+        # Pattern matching graph sanitization may remove clone ops, which means
+        # the user's input in the original graph may be a clone op. When
+        # replacing the input with the marker node, we need to further try
+        # replacing the input of the clone op that connects to the user.
+        for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+          if fx_utils.is_clone_op(original_user_input):
+            original_user_input.replace_input_with(
+                input_node.meta["ORIGINAL_NODE"], new_input_node
+            )
 
   for i, pattern_output_node in enumerate(pattern.output_nodes):
     output_node = match.nodes_map[pattern_output_node]
ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py}
@@ -12,11 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
+"""FX graph utilities for pattern matching clean ups."""
 
 import torch
 
 
+def is_clone_op(node: torch.fx.Node) -> bool:
+  """Checks if the node is a clone op."""
+  return (
+      node.op == "call_function" and node.target == torch.ops.aten.clone.default
+  )
+
+
 def remove_clone_ops(gm: torch.fx.GraphModule):
   """Removes clone ops from the graph.
 
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
     The graph module with clone ops removed.
   """
   for node in gm.graph.nodes:
-    if node
+    if is_clone_op(node):
       node.replace_all_uses_with(node.args[0])
       gm.graph.erase_node(node)
 
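The rename from passes.py to fx_utils.py factors the clone check into is_clone_op so that both remove_clone_ops and the new clone-aware input rewiring in mark_pattern can share it. A minimal usage sketch, assuming a toy module (the model below is illustrative, not from the package):

    import torch
    from ai_edge_torch.hlfb.mark_pattern import fx_utils

    class M(torch.nn.Module):
      def forward(self, x):
        return torch.ops.aten.clone.default(x * x) + x

    ep = torch.export.export(M().eval(), (torch.rand(4, 4),))
    gm = fx_utils.remove_clone_ops(ep.graph_module)
    # After the pass, no remaining node should be a clone op.
    assert not any(fx_utils.is_clone_op(n) for n in gm.graph.nodes)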
ai_edge_torch/hlfb/mark_pattern/pattern.py
@@ -18,13 +18,14 @@ import dataclasses
 from typing import Any, Callable, Optional, Union
 
 from ai_edge_torch import fx_pass_base
-from ai_edge_torch.hlfb.mark_pattern import
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 import torch
-
-
-
-
-
+
+Graph = torch.fx.Graph
+GraphModule = torch.fx.GraphModule
+TensorArgument = torch.export.graph_signature.TensorArgument
+InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
 
 
 def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
     # Sanitize graph_module for more precise pattern matching.
     # The graph_module to match against this pattern should apply equivalent
     # sanitization.
-    self.graph_module =
-    self.graph_module =
+    self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+    self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
 
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}
ai_edge_torch/hlfb/test/test_mark_pattern.py
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
         {"stablehlo.custom_call @mark_tensor": 6},
     )
 
+  def test_mark_pattern_with_clone_inputs(self):
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.clone.default(x * x) + x
+
+    pattern = pattern_module.Pattern(
+        "test.add",
+        lambda a, b: a + b,
+        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+    )
+
+    model = TestModel().eval()
+    args = (torch.rand(20, 20),)
+    exported_program = torch.export.export(model, args)
+    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+    mlir = _export_stablehlo_mlir(exported_program)
+
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.add"': 1},
+        {"stablehlo.custom_call @mark_tensor": 3},
+    )
+
   def test_mark_pattern_with_attr_builder(self):
     class TestModel(torch.nn.Module):
 
ai_edge_torch/odml_torch/_torch_future.py
@@ -73,3 +73,16 @@ def safe_run_decompositions(exported_program, decomp_table=None):
       node.target = lambda self, size: torch.reshape(self.contiguous(), size)
 
   return exported_program.run_decompositions(decomp_table)
+
+
+def dummy_decomp_table():
+  """Build dummy decomp table for run_decompositions without any decompositions.
+
+  Compatible for torch<=2.5.
+
+  Returns:
+    Decomp table for ExportedProgram.run_decompositions.
+  """
+  return {
+      torch._ops.OperatorBase(): lambda: None,
+  }
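dummy_decomp_table() exists so that run_decompositions can be called without applying any real decompositions, effectively just retracing the exported program (the docstring notes compatibility with torch<=2.5). This is how the export path a few hunks below uses it; exported_program here stands for an existing torch.export.ExportedProgram:

    from ai_edge_torch.odml_torch import _torch_future

    exported_program = _torch_future.safe_run_decompositions(
        exported_program, _torch_future.dummy_decomp_table()
    )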
ai_edge_torch/odml_torch/export.py
@@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
   def in_i32(x: int):
     return -2147483648 <= x <= 2147483647
 
+  def to_int32(x: torch.Tensor):
+    return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
+
   def rewrite_arange(node: torch.fx.Node):
     tensor_meta = node.meta.get("tensor_meta", None)
     if not tensor_meta:
@@ -249,7 +252,7 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
     if not (in_i32(start) and in_i32(end)):
       return
     op = node.target
-    node.target = lambda *args, **kwargs: op(*args, **kwargs)
+    node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
 
   graph_module = exported_program.graph_module
   for node in graph_module.graph.nodes:
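The change wraps each in-range arange node so the original op still runs and only its int64 result is copied to int32. A standalone illustration of the wrapped call (the overload chosen here is for the example; arange nodes in a real graph may use other overloads):

    import torch

    op = torch.ops.aten.arange.start_step
    out = torch.ops.aten._to_copy.default(op(0, 8, 1), dtype=torch.int32)
    assert out.dtype == torch.int32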
@@ -305,8 +308,9 @@ def exported_program_to_mlir(
 
   _convert_i64_to_i32(exported_program)
 
+  # No decompositions but just retracing/cananicalization.
   exported_program = _torch_future.safe_run_decompositions(
-      exported_program,
+      exported_program, _torch_future.dummy_decomp_table()
   )
 
   # Passes below mutate the exported program to a state not executable by torch.
ai_edge_torch/odml_torch/lowerings/decomp.py
@@ -55,6 +55,10 @@ def decompositions():
       ],
   )
 
+  # Override noop aten op decompositions for faster run_decompositions.
+  decompositions[torch.ops.aten.alias.default] = lambda x: x
+  decompositions[torch.ops.aten.detach.default] = lambda x: x
+
   # Override _safe_softmax decompositions with regular softmax.
   # _safe_softmax introduces additional check-select ops to guard extreme
   # input values to softmax, which could make the converted model inefficient
ai_edge_torch/version.py CHANGED