ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -30,6 +30,7 @@ class ActivationType(enum.Enum):
   GELU_QUICK = enum.auto()
   GE_GLU = enum.auto()
   RELU = enum.auto()
+  SILU_GLU = enum.auto()


 @enum.unique
@@ -58,6 +59,18 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()


+@dataclass
+class NormalizationConfig:
+  """Normalizater parameters."""
+
+  type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
+  epsilon: float = 1e-5
+  zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None
+
+
 @dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -81,6 +94,14 @@ class AttentionConfig:
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
+  # The normalization applied to query projection's output.
+  query_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+  # The normalization applied to key projection's output.
+  key_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
@@ -94,21 +115,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  #
-
-
-
-
-@dataclass
-class NormalizationConfig:
-  """Normalizater parameters."""
-
-  type: NormalizationType = NormalizationType.NONE
-  enable_hlfb: bool = False
-  epsilon: float = 1e-5
-  zero_centered: bool = False
-  # Number of groups used in group normalization.
-  group_num: Optional[float] = None
+  # Whether to GLU gate is the front part instead of the back part of input
+  # when ActivationType is `GE_GLU` or `SILU_GLU`.
+  gate_is_front: bool = False


 @dataclass
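Net effect of these four hunks: NormalizationConfig moves above AttentionConfig so the new query_norm_config/key_norm_config fields can reference it, SILU_GLU joins the activation enum, and ActivationConfig gains a gate_is_front flag for the GLU variants. A minimal sketch of how the new fields compose; the RMS_NORM enum member and the other AttentionConfig fields are assumed from the existing module, not shown in this diff:

from ai_edge_torch.generative.layers import model_config as cfg

# Hypothetical wiring; only query_norm_config, key_norm_config, and
# gate_is_front come from this diff, the remaining fields are assumed.
attn_config = cfg.AttentionConfig(
    num_heads=8,
    head_dim=64,
    num_query_groups=1,
    # Apply RMSNorm to the query/key projection outputs.
    query_norm_config=cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM
    ),
    key_norm_config=cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM
    ),
)

ff_activation = cfg.ActivationConfig(
    type=cfg.ActivationType.SILU_GLU,
    gate_is_front=True,  # gate half precedes the value half of the input
)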
ai_edge_torch/generative/quantize/example.py
CHANGED
@@ -25,9 +25,9 @@ def main():
   config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, 10)
+  input_pos = torch.arange(0, 10, dtype=torch.int)

   # Create a quantization recipe to be applied to the model
   quant_config = quant_recipes.full_int8_dynamic_recipe()
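The two-line change above recurs throughout this release, including the test files below: PyTorch's integer factory functions default to int64, and the converted models now pin token and position inputs to int32. A quick check of the defaults (plain PyTorch, not part of the package):

import torch

# Without an explicit dtype, integer range/full tensors come out as int64.
assert torch.arange(0, 10).dtype == torch.int64
# torch.int is an alias for torch.int32, which the examples now request.
assert torch.arange(0, 10, dtype=torch.int).dtype == torch.int32
assert torch.full((1, 10), 0, dtype=torch.int).dtype == torch.int32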
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED
@@ -42,15 +42,9 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-
-
-
-  )
-  def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
+  def _test_model_with_kv_cache(self, config, pytorch_model):
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
+        [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)

@@ -83,58 +77,32 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
+  def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
-    )
-
-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            tokens,
-            input_pos,
-            kv,
-            signature_name="serving_default",
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    self._test_model_with_kv_cache(config, pytorch_model)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
-    config =
-
+  def test_toy_model_with_kv_cache_with_hlfb(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = True
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)

+  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
-    prefill_input_pos = torch.arange(0, seq_len)
+    prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)

     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.
-    decode_input_pos = torch.tensor([5], dtype=torch.
+    decode_token = torch.tensor([[1]], dtype=torch.int)
+    decode_input_pos = torch.tensor([5], dtype=torch.int)

     kv = kv_cache.KVCache.from_model_config(config)

@@ -171,8 +139,8 @@ class TestModelConversion(googletest.TestCase):
             prefill_input_pos,
             kv,
             signature_name="prefill",
-            atol=
-            rtol=
+            atol=atol,
+            rtol=atol,
         )
     )

@@ -184,11 +152,20 @@ class TestModelConversion(googletest.TestCase):
             decode_input_pos,
             kv,
             signature_name="decode",
-            atol=
-            rtol=
+            atol=atol,
+            rtol=atol,
         )
     )

+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_tiny_llama_multisig(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
+

 if __name__ == "__main__":
   googletest.main()
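The refactor above folds the duplicated test bodies into _test_model_with_kv_cache and _test_multisig_model helpers, and test_tiny_llama_multisig reuses the latter. A sketch of the prefill/decode two-signature conversion that _test_multisig_model drives, built from calls that appear in this diff; the chained .signature(...) on the returned converter is assumed from the ai_edge_torch API rather than shown here:

import torch
import ai_edge_torch
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.layers import kv_cache

config = toy_model_with_kv_cache.get_model_config()
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()

prefill_tokens = torch.full((1, 10), 0, dtype=torch.int)
prefill_input_pos = torch.arange(0, 10, dtype=torch.int)
decode_token = torch.tensor([[1]], dtype=torch.int)
decode_input_pos = torch.tensor([5], dtype=torch.int)
kv = kv_cache.KVCache.from_model_config(config)

# One TFLite model with a "prefill" and a "decode" signature.
edge_model = (
    ai_edge_torch.signature(
        "prefill",
        pytorch_model,
        sample_kwargs={
            "tokens": prefill_tokens,
            "input_pos": prefill_input_pos,
            "kv_cache": kv,
        },
    )
    .signature(
        "decode",
        pytorch_model,
        sample_kwargs={
            "tokens": decode_token,
            "input_pos": decode_input_pos,
            "kv_cache": kv,
        },
    )
    .convert()
)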
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -19,7 +19,9 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -43,28 +45,22 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    model = gemma.Gemma(config)
-
+  def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)

-    edge_model = ai_edge_torch.convert(
+    edge_model = ai_edge_torch.signature(
+        signature_name,
         model,
         sample_kwargs={
             "tokens": tokens,
             "input_pos": input_pos,
             "kv_cache": kv,
         },
-    )
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
@@ -76,9 +72,9 @@ class TestModelConversion(googletest.TestCase):
             tokens,
             input_pos,
             kv,
-            signature_name=
-            atol=
-            rtol=
+            signature_name=signature_name,
+            atol=atol,
+            rtol=rtol,
         )
     )

@@ -86,42 +82,21 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
-    config =
-
-
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    prefill_tokens[0, :4] = idx
-    prefill_input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.signature(
-        "prefill",
-        model,
-        sample_kwargs={
-            "tokens": prefill_tokens,
-            "input_pos": prefill_input_pos,
-            "kv_cache": kv,
-        },
-    ).convert()
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    pytorch_model = gemma.Gemma(config).eval()
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )

-
-
-
-
-
-
-
-
-    atol=1e-1,
-    rtol=1e-3,
-    )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    pytorch_model = gemma2.Gemma2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -130,37 +105,27 @@ class TestModelConversion(googletest.TestCase):
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
-
-
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )

-
-
-
-
-
-
-
-
-
-
-
-
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_smollm(self):
+    config = smollm.get_fake_model_config()
+    pytorch_model = smollm.SmolLM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_openelm(self):
+    config = openelm.get_fake_model_config()
+    pytorch_model = openelm.OpenELM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)


 if __name__ == "__main__":
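With _test_model in place, covering another example model is a three-line test body, exactly the shape of the added test_smollm and test_openelm. A sketch for a hypothetical future example inside TestModelConversion (the newmodel module and NewModel class are invented for illustration):

  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
  )
  def test_newmodel(self):
    # Hypothetical example module; not part of the package.
    config = newmodel.get_fake_model_config()
    pytorch_model = newmodel.NewModel(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)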
ai_edge_torch/generative/test/test_quantize.py
CHANGED
@@ -115,8 +115,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)

     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
@@ -131,8 +131,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_weight_sharing(self):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)

     quant_config = quant_recipes.full_int8_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
@@ -149,7 +149,7 @@ class TestQuantizeConvert(parameterized.TestCase):
       self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )

ai_edge_torch/generative/utilities/loader.py
CHANGED
@@ -101,6 +101,8 @@ class ModelLoader:
     attn_value_proj: str = None
     attn_fused_qkv_proj: str = None
     attn_output_proj: str = None
+    attn_query_norm: str = None
+    attn_key_norm: str = None

     ff_up_proj: str = None
     ff_down_proj: str = None
@@ -323,6 +325,17 @@ class ModelLoader:
         )
     )

+    if self._names.attn_query_norm is not None:
+      attn_query_norm_name = self._names.attn_query_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+          f"{attn_query_norm_name}.weight"
+      )
+    if self._names.attn_key_norm is not None:
+      attn_key_norm_name = self._names.attn_key_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+          f"{attn_key_norm_name}.weight"
+      )
+
     o_name = self._names.attn_output_proj.format(idx)
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
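The new attn_query_norm/attn_key_norm name patterns let a checkpoint's per-layer Q/K norm weights be mapped onto atten_func.query_norm and atten_func.key_norm, with "{}" replaced by the layer index as the hunk above shows. A hypothetical mapping, assuming the nested TensorNames dataclass and the projection fields from the existing class; the tensor paths are illustrative, not from a real checkpoint:

from ai_edge_torch.generative.utilities import loader

TENSOR_NAMES = loader.ModelLoader.TensorNames(
    attn_query_proj="transformer.layers.{}.attn.q_proj",
    attn_key_proj="transformer.layers.{}.attn.k_proj",
    attn_value_proj="transformer.layers.{}.attn.v_proj",
    attn_output_proj="transformer.layers.{}.attn.out_proj",
    # New in this release; "{}" is filled with the layer index.
    attn_query_norm="transformer.layers.{}.attn.q_norm",
    attn_key_norm="transformer.layers.{}.attn.k_norm",
)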
ai_edge_torch/odml_torch/export.py
CHANGED
@@ -223,6 +223,41 @@ class MlirLowered:
     return tf_integration.mlir_to_flatbuffer(self)


+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
+  """Convert internal constant aten ops' output from int64 to int32.
+
+  Int32 generally has better performance and compatibility than int64 in
+  runtime. This pass converts aten op where the output(s) are int64 constant
+  tensors to return int32 constant tensors.
+
+  Args:
+    exported_program: The exported program to apply the pass.
+  """
+
+  def in_i32(x: int):
+    return -2147483648 <= x <= 2147483647
+
+  def rewrite_arange(node: torch.fx.Node):
+    tensor_meta = node.meta.get("tensor_meta", None)
+    if not tensor_meta:
+      return
+
+    start, end = node.args[:2]
+    if tensor_meta.dtype != torch.int64:
+      return
+    if not (in_i32(start) and in_i32(end)):
+      return
+    op = node.target
+    node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+
+    if node.target == torch.ops.aten.arange.start_step:
+      rewrite_arange(node)
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -231,6 +266,11 @@ def exported_program_to_mlir(
       lowerings.decompositions()
   )

+  _convert_i64_to_i32(exported_program)
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
+
   with export_utils.create_ir_context() as context, ir.Location.unknown():

     module = ir.Module.create()
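Why the pass above singles out aten.arange.start_step: without an explicit dtype, integer arange produces int64, and the rewrite swaps the node's target for a wrapper that casts to int32 in-graph, but only when start and end fit in 32 bits. A standalone illustration of both halves (plain PyTorch, not part of the package):

import torch

# Integer arange defaults to int64, the dtype the pass wants to avoid.
assert torch.ops.aten.arange.start_step(0, 10, 1).dtype == torch.int64

# The rewrite replaces node.target with a casting wrapper like this one.
op = torch.ops.aten.arange.start_step
wrapped = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
assert wrapped(0, 10, 1).dtype == torch.int32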
ai_edge_torch/odml_torch/lowerings/_basic.py
CHANGED
@@ -202,3 +202,47 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   x, y = utils.broadcast_args_if_needed(x, y)

   return stablehlo.divide(x, y)
+
+
+# Schema:
+#   - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
+#     start=None, SymInt? end=None, SymInt step=1) -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/stable/generated/torch.slice_scatter.html
+#   - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
+@lower(torch.ops.aten.slice_scatter)
+def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
+  start = start or 0
+  end = end or self.type.shape[dim]
+  if start < 0:
+    start = self.type.shape[dim] + start
+  if end < 0:
+    end = self.type.shape[dim] + end
+
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
+
+  padding_low = start
+  padding_high = self.type.shape[dim] - end
+
+  rank = len(self.type.shape)
+  src = stablehlo.pad(
+      src,
+      utils.splat(0, src.type.element_type, []),
+      edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
+      edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
+      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+  )
+  pred = np.ones(self.type.shape, dtype=np.bool_)
+  pred[*[
+      slice(start, end, step) if i == dim else slice(None, None, None)
+      for i in range(rank)
+  ]] = False
+  pred = stablehlo.constant(
+      ir.DenseElementsAttr.get(
+          np.packbits(pred, bitorder="little"),
+          type=ir.IntegerType.get_signless(1),
+          shape=pred.shape,
+      )
+  )
+  out = stablehlo.select(pred, self, src)
+  return out
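The lowering above implements slice_scatter as interior/edge padding of src followed by an element-wise select: padding aligns src with the scattered slice positions, and a boolean mask chooses between self and the padded src. A plain-PyTorch sanity check of that decomposition (illustrative, not part of the package):

import torch

self_t = torch.zeros(8)
src = torch.tensor([1.0, 2.0, 3.0])
dim, start, end, step = 0, 1, 7, 2

expected = torch.slice_scatter(self_t, src, dim, start, end, step)

# Emulate stablehlo.pad: `start` zeros in front, step-1 zeros between
# elements, and trailing zeros up to the full length.
padded = torch.zeros_like(self_t)
padded[start:start + step * len(src):step] = src

# Emulate the select: mask is False exactly at the scattered positions.
mask = torch.ones(8, dtype=torch.bool)
mask[start:start + step * len(src):step] = False
result = torch.where(mask, self_t, padded)

assert torch.equal(result, expected)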
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
CHANGED
@@ -203,7 +203,6 @@ lower_by_torch_xla2(torch.ops.aten.sin)
 lower_by_torch_xla2(torch.ops.aten.sinh)
 lower_by_torch_xla2(torch.ops.aten.slice)
 lower_by_torch_xla2(torch.ops.aten.slice_copy)
-lower_by_torch_xla2(torch.ops.aten.slice_scatter)
 lower_by_torch_xla2(torch.ops.aten.sort)
 lower_by_torch_xla2(torch.ops.aten.split)
 lower_by_torch_xla2(torch.ops.aten.split_copy)
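This removal pairs with the dedicated _aten_slice_scatter StableHLO lowering added in _basic.py above: aten.slice_scatter no longer routes through the generic torch-xla2 fallback.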
{ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240913
+Version: 0.3.0.dev20240914
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI