ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -30,6 +30,7 @@ class ActivationType(enum.Enum):
   GELU_QUICK = enum.auto()
   GE_GLU = enum.auto()
   RELU = enum.auto()
+  SILU_GLU = enum.auto()
 
 
 @enum.unique
@@ -58,6 +59,18 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()
 
 
+@dataclass
+class NormalizationConfig:
+  """Normalizater parameters."""
+
+  type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
+  epsilon: float = 1e-5
+  zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None
+
+
 @dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -81,6 +94,14 @@ class AttentionConfig:
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
+  # The normalization applied to query projection's output.
+  query_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+  # The normalization applied to key projection's output.
+  key_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
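For illustration, a minimal sketch of how the new per-projection normalization fields might be filled in. The constructor arguments other than query_norm_config/key_norm_config (num_heads, head_dim, num_query_groups) and NormalizationType.RMS_NORM are assumed from the surrounding module, not shown in this diff:

# Hedged sketch: attach RMSNorm to the query and key projections via the new
# fields added above. Field names other than the two *_norm_config fields are
# assumptions about model_config, not confirmed by this diff.
from ai_edge_torch.generative.layers import model_config as cfg

attn_config = cfg.AttentionConfig(
    num_heads=16,
    head_dim=64,
    num_query_groups=4,
    query_norm_config=cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
    ),
    key_norm_config=cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
    ),
)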
@@ -94,21 +115,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  #
-
-
-
-
-@dataclass
-class NormalizationConfig:
-  """Normalizater parameters."""
-
-  type: NormalizationType = NormalizationType.NONE
-  enable_hlfb: bool = False
-  epsilon: float = 1e-5
-  zero_centered: bool = False
-  # Number of groups used in group normalization.
-  group_num: Optional[float] = None
+  # Whether to GLU gate is the front part instead of the back part of input
+  # when ActivationType is `GE_GLU` or `SILU_GLU`.
+  gate_is_front: bool = False
 
 
 @dataclass
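A hedged sketch (plain PyTorch, not the library's feed-forward code) of the gating that gate_is_front controls: the projected tensor is split in half and one half gates the other.

import torch
import torch.nn.functional as F

def glu(x: torch.Tensor, gate_is_front: bool) -> torch.Tensor:
  # Split the projection in half; gate_is_front picks which half is the gate.
  a, b = x.chunk(2, dim=-1)
  gate, value = (a, b) if gate_is_front else (b, a)
  return F.silu(gate) * value  # SILU_GLU; GE_GLU would apply F.gelu instead

out = glu(torch.randn(2, 8), gate_is_front=False)  # -> shape (2, 4)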
ai_edge_torch/generative/quantize/example.py CHANGED
@@ -25,9 +25,9 @@ def main():
   config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, 10)
+  input_pos = torch.arange(0, 10, dtype=torch.int)
 
   # Create a quantization recipe to be applied to the model
   quant_config = quant_recipes.full_int8_dynamic_recipe()
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -42,15 +42,9 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
+  def _test_model_with_kv_cache(self, config, pytorch_model):
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
+        [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)
 
@@ -83,58 +77,32 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_toy_model_with_kv_cache_with_hlfb(self):
+  def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
-    )
-
-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            tokens,
-            input_pos,
-            kv,
-            signature_name="serving_default",
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    self._test_model_with_kv_cache(config, pytorch_model)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_tiny_llama_multisig(self):
-    config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+  def test_toy_model_with_kv_cache_with_hlfb(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = True
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)
 
+  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
-    prefill_input_pos = torch.arange(0, seq_len)
+    prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
 
     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.long)
-    decode_input_pos = torch.tensor([5], dtype=torch.int64)
+    decode_token = torch.tensor([[1]], dtype=torch.int)
+    decode_input_pos = torch.tensor([5], dtype=torch.int)
 
     kv = kv_cache.KVCache.from_model_config(config)
 
@@ -171,8 +139,8 @@ class TestModelConversion(googletest.TestCase):
             prefill_input_pos,
             kv,
             signature_name="prefill",
-            atol=1e-5,
-            rtol=1e-5,
+            atol=atol,
+            rtol=atol,
         )
     )
 
@@ -184,11 +152,20 @@ class TestModelConversion(googletest.TestCase):
             decode_input_pos,
             kv,
             signature_name="decode",
-            atol=1e-5,
-            rtol=1e-5,
+            atol=atol,
+            rtol=atol,
         )
     )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_tiny_llama_multisig(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
+
 
 if __name__ == "__main__":
   googletest.main()
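The multisig test above converts one PyTorch module under two named signatures. A hedged usage sketch of what that export might look like, assuming the Converter returned by ai_edge_torch.signature supports chaining a second .signature() call before .convert(), which these tests imply but this diff does not show:

import torch
import ai_edge_torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache

# Build a small model and KV cache the same way the tests do.
config = tiny_llama.get_fake_model_config()
pytorch_model = tiny_llama.TinyLlama(config).eval()
kv = kv_cache.KVCache.from_model_config(config)

prefill_tokens = torch.full((1, 10), 0, dtype=torch.int)
prefill_input_pos = torch.arange(0, 10, dtype=torch.int)
decode_token = torch.tensor([[1]], dtype=torch.int)
decode_input_pos = torch.tensor([5], dtype=torch.int)

# One TFLite model exposing "prefill" and "decode" entry points (sketch only).
edge_model = (
    ai_edge_torch.signature(
        "prefill",
        pytorch_model,
        sample_kwargs={
            "tokens": prefill_tokens,
            "input_pos": prefill_input_pos,
            "kv_cache": kv,
        },
    )
    .signature(
        "decode",
        pytorch_model,
        sample_kwargs={
            "tokens": decode_token,
            "input_pos": decode_input_pos,
            "kv_cache": kv,
        },
    )
    .convert()
)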
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,7 +19,9 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -43,28 +45,22 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    model = gemma.Gemma(config)
-
+  def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
 
-    edge_model = ai_edge_torch.convert(
+    edge_model = ai_edge_torch.signature(
+        signature_name,
         model,
         sample_kwargs={
            "tokens": tokens,
            "input_pos": input_pos,
            "kv_cache": kv,
        },
-    )
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
@@ -76,9 +72,9 @@ class TestModelConversion(googletest.TestCase):
             tokens,
             input_pos,
             kv,
-            signature_name="serving_default",
-            atol=1e-2,
-            rtol=1e-5,
+            signature_name=signature_name,
+            atol=atol,
+            rtol=rtol,
         )
     )
 
@@ -86,42 +82,21 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_gemma2(self):
-    config = gemma2.get_fake_model_config()
-
-
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    prefill_tokens[0, :4] = idx
-    prefill_input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.signature(
-        "prefill",
-        model,
-        sample_kwargs={
-            "tokens": prefill_tokens,
-            "input_pos": prefill_input_pos,
-            "kv_cache": kv,
-        },
-    ).convert()
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    pytorch_model = gemma.Gemma(config).eval()
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
 
-
-
-
-
-
-
-
-
-            atol=1e-1,
-            rtol=1e-3,
-        )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    pytorch_model = gemma2.Gemma2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -130,37 +105,27 @@ class TestModelConversion(googletest.TestCase):
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
-
-
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_smollm(self):
+    config = smollm.get_fake_model_config()
+    pytorch_model = smollm.SmolLM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_openelm(self):
+    config = openelm.get_fake_model_config()
+    pytorch_model = openelm.OpenELM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
 
 if __name__ == "__main__":
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -115,8 +115,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)
 
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
@@ -131,8 +131,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_weight_sharing(self):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)
 
     quant_config = quant_recipes.full_int8_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
@@ -149,7 +149,7 @@ class TestQuantizeConvert(parameterized.TestCase):
       self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -101,6 +101,8 @@ class ModelLoader:
     attn_value_proj: str = None
     attn_fused_qkv_proj: str = None
     attn_output_proj: str = None
+    attn_query_norm: str = None
+    attn_key_norm: str = None
 
     ff_up_proj: str = None
     ff_down_proj: str = None
@@ -323,6 +325,17 @@ class ModelLoader:
         )
     )
 
+    if self._names.attn_query_norm is not None:
+      attn_query_norm_name = self._names.attn_query_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+          f"{attn_query_norm_name}.weight"
+      )
+    if self._names.attn_key_norm is not None:
+      attn_key_norm_name = self._names.attn_key_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+          f"{attn_key_norm_name}.weight"
+      )
+
     o_name = self._names.attn_output_proj.format(idx)
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
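A hedged sketch of how a model definition might point the two new tensor-name fields at per-layer q/k norm weights in a checkpoint. The field names are taken from the hunks above; the checkpoint key patterns are invented for illustration, and the nested TensorNames dataclass is assumed to be where these fields live:

# Hypothetical mapping; the "model.layers.{}..." key strings are placeholders.
from ai_edge_torch.generative.utilities import loader as loading_utils

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    attn_fused_qkv_proj="model.layers.{}.attn.qkv_proj",
    attn_output_proj="model.layers.{}.attn.out_proj",
    attn_query_norm="model.layers.{}.attn.q_norm",
    attn_key_norm="model.layers.{}.attn.k_norm",
)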
ai_edge_torch/odml_torch/export.py CHANGED
@@ -223,6 +223,41 @@ class MlirLowered:
     return tf_integration.mlir_to_flatbuffer(self)
 
 
+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
+  """Convert internal constant aten ops' output from int64 to int32.
+
+  Int32 generally has better performance and compatibility than int64 in
+  runtime. This pass converts aten op where the output(s) are int64 constant
+  tensors to return int32 constant tensors.
+
+  Args:
+    exported_program: The exported program to apply the pass.
+  """
+
+  def in_i32(x: int):
+    return -2147483648 <= x <= 2147483647
+
+  def rewrite_arange(node: torch.fx.Node):
+    tensor_meta = node.meta.get("tensor_meta", None)
+    if not tensor_meta:
+      return
+
+    start, end = node.args[:2]
+    if tensor_meta.dtype != torch.int64:
+      return
+    if not (in_i32(start) and in_i32(end)):
+      return
+    op = node.target
+    node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+
+    if node.target == torch.ops.aten.arange.start_step:
+      rewrite_arange(node)
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -231,6 +266,11 @@ def exported_program_to_mlir(
       lowerings.decompositions()
   )
 
+  _convert_i64_to_i32(exported_program)
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
+
   with export_utils.create_ir_context() as context, ir.Location.unknown():
 
     module = ir.Module.create()
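To make the arange rewrite above concrete, a small hedged demonstration of the wrapper trick on the aten overload it targets, in plain PyTorch and independent of the export pipeline:

import torch

# torch.arange with integer arguments defaults to int64, which the pass narrows.
print(torch.arange(0, 10).dtype)  # torch.int64

# The same wrapping idea, standalone: call the original overload, then cast.
op = torch.ops.aten.arange.start_step
wrapped = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
print(wrapped(0, 10, 1).dtype)  # torch.int32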
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -202,3 +202,47 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   x, y = utils.broadcast_args_if_needed(x, y)
 
   return stablehlo.divide(x, y)
+
+
+# Schema:
+#   - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
+#     start=None, SymInt? end=None, SymInt step=1) -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/stable/generated/torch.slice_scatter.html
+#   - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
+@lower(torch.ops.aten.slice_scatter)
+def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
+  start = start or 0
+  end = end or self.type.shape[dim]
+  if start < 0:
+    start = self.type.shape[dim] + start
+  if end < 0:
+    end = self.type.shape[dim] + end
+
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
+
+  padding_low = start
+  padding_high = self.type.shape[dim] - end
+
+  rank = len(self.type.shape)
+  src = stablehlo.pad(
+      src,
+      utils.splat(0, src.type.element_type, []),
+      edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
+      edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
+      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+  )
+  pred = np.ones(self.type.shape, dtype=np.bool_)
+  pred[*[
+      slice(start, end, step) if i == dim else slice(None, None, None)
+      for i in range(rank)
+  ]] = False
+  pred = stablehlo.constant(
+      ir.DenseElementsAttr.get(
+          np.packbits(pred, bitorder="little"),
+          type=ir.IntegerType.get_signless(1),
+          shape=pred.shape,
+      )
+  )
+  out = stablehlo.select(pred, self, src)
+  return out
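For reference, the op semantics the new lowering reproduces with stablehlo.pad plus select: torch.slice_scatter writes src into a strided slice of the input.

import torch

self_t = torch.zeros(8)
src = torch.ones(3)
# Scatter src into indices 1, 3, 5 of self_t (slice(1, 7, 2)).
out = torch.slice_scatter(self_t, src, dim=0, start=1, end=7, step=2)
print(out)  # tensor([0., 1., 0., 1., 0., 1., 0., 0.])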
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -203,7 +203,6 @@ lower_by_torch_xla2(torch.ops.aten.sin)
 lower_by_torch_xla2(torch.ops.aten.sinh)
 lower_by_torch_xla2(torch.ops.aten.slice)
 lower_by_torch_xla2(torch.ops.aten.slice_copy)
-lower_by_torch_xla2(torch.ops.aten.slice_scatter)
 lower_by_torch_xla2(torch.ops.aten.sort)
 lower_by_torch_xla2(torch.ops.aten.split)
 lower_by_torch_xla2(torch.ops.aten.split_copy)
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240913
+Version: 0.3.0.dev20240914
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI