ai-edge-torch-nightly 0.3.0.dev20241213__py3-none-any.whl → 0.3.0.dev20241215__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/generative/layers/kv_cache.py +1 -2
- ai_edge_torch/generative/test/test_model_conversion.py +8 -9
- ai_edge_torch/generative/test/test_model_conversion_large.py +26 -27
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241215.dist-info}/METADATA +6 -4
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241215.dist-info}/RECORD +18 -19
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241215.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241215.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241215.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -13,13 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from ai_edge_torch._config import config
|
16
17
|
from ai_edge_torch._convert.converter import convert
|
17
18
|
from ai_edge_torch._convert.converter import signature
|
18
19
|
from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
|
19
20
|
from ai_edge_torch.model import Model
|
20
21
|
from ai_edge_torch.version import __version__
|
21
22
|
|
22
|
-
|
23
23
|
def load(path: str) -> Model:
|
24
24
|
"""Imports an ai_edge_torch model from disk.
|
25
25
|
|
ai_edge_torch/_config.py
ADDED
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Provides a configuration for the ai-edge-torch."""
|
17
|
+
|
18
|
+
import functools
|
19
|
+
import logging
|
20
|
+
import os
|
21
|
+
|
22
|
+
__all__ = ["config"]
|
23
|
+
|
24
|
+
|
25
|
+
class _Config:
|
26
|
+
"""ai-edge-torch global configs."""
|
27
|
+
|
28
|
+
@property
|
29
|
+
@functools.cache # pylint: disable=method-cache-max-size-none
|
30
|
+
def use_torch_xla(self) -> bool:
|
31
|
+
"""True if using torch_xla to lower torch ops to StableHLO.
|
32
|
+
|
33
|
+
To use torch_xla as the lowering backend, set environment variable
|
34
|
+
`USE_TORCH_XLA` to "true".
|
35
|
+
"""
|
36
|
+
var = os.environ.get("USE_TORCH_XLA", "false")
|
37
|
+
var = var.lower().strip()
|
38
|
+
if var in ("y", "yes", "t", "true", "on", "1"):
|
39
|
+
return True
|
40
|
+
elif var in ("n", "no", "f", "false", "off", "0"):
|
41
|
+
return False
|
42
|
+
else:
|
43
|
+
logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
|
44
|
+
return False
|
45
|
+
|
46
|
+
@property
|
47
|
+
def in_oss(self) -> bool:
|
48
|
+
"""True if the code is not running in google internal environment."""
|
49
|
+
return True
|
50
|
+
|
51
|
+
|
52
|
+
config = _Config()
|
@@ -19,7 +19,6 @@ import os
|
|
19
19
|
from typing import Tuple
|
20
20
|
|
21
21
|
import ai_edge_torch
|
22
|
-
from ai_edge_torch import config
|
23
22
|
from ai_edge_torch._convert import conversion_utils
|
24
23
|
from ai_edge_torch.quantize import pt2e_quantizer
|
25
24
|
from ai_edge_torch.testing import model_coverage
|
@@ -292,7 +291,7 @@ class TestConvert(googletest.TestCase):
|
|
292
291
|
self.assertTrue(result)
|
293
292
|
|
294
293
|
@googletest.skipIf(
|
295
|
-
not config.
|
294
|
+
not ai_edge_torch.config.use_torch_xla,
|
296
295
|
reason="Shape polymorphism is not yet support with odml_torch.",
|
297
296
|
)
|
298
297
|
def test_convert_model_with_dynamic_batch(self):
|
@@ -20,6 +20,7 @@ from typing import List, Tuple
|
|
20
20
|
|
21
21
|
from ai_edge_torch import hlfb
|
22
22
|
from ai_edge_torch.generative.layers import model_config
|
23
|
+
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
|
23
24
|
import torch
|
24
25
|
import torch.utils._pytree as pytree
|
25
26
|
|
@@ -159,8 +160,6 @@ def update(
|
|
159
160
|
Returns:
|
160
161
|
KVCacheEntry: The updated KVCache entry based on the passed inputs.
|
161
162
|
"""
|
162
|
-
# Turn dynamic_update_slice updates off for now.
|
163
|
-
use_dus=False
|
164
163
|
update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
|
165
164
|
return update_kv_cache(cache, input_pos, k_slice, v_slice)
|
166
165
|
|
@@ -16,7 +16,6 @@
|
|
16
16
|
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
|
-
from ai_edge_torch import config as ai_edge_config
|
20
19
|
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
|
21
20
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
22
21
|
from ai_edge_torch.generative.layers import kv_cache
|
@@ -83,22 +82,22 @@ class TestModelConversion(googletest.TestCase):
|
|
83
82
|
)
|
84
83
|
|
85
84
|
@googletest.skipIf(
|
86
|
-
|
87
|
-
reason="tests with custom ops are not supported
|
85
|
+
ai_edge_torch.config.in_oss,
|
86
|
+
reason="tests with custom ops are not supported in oss",
|
88
87
|
)
|
89
88
|
def test_toy_model_with_kv_cache(self):
|
90
89
|
self._test_model_with_kv_cache(enable_hlfb=False)
|
91
90
|
|
92
91
|
@googletest.skipIf(
|
93
|
-
|
94
|
-
reason="tests with custom ops are not supported
|
92
|
+
ai_edge_torch.config.in_oss,
|
93
|
+
reason="tests with custom ops are not supported in oss",
|
95
94
|
)
|
96
95
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
97
96
|
self._test_model_with_kv_cache(enable_hlfb=True)
|
98
97
|
|
99
98
|
@googletest.skipIf(
|
100
|
-
|
101
|
-
reason="tests with custom ops are not supported
|
99
|
+
ai_edge_torch.config.in_oss,
|
100
|
+
reason="tests with custom ops are not supported in oss",
|
102
101
|
)
|
103
102
|
def test_toy_model_has_dus_op(self):
|
104
103
|
"""Tests that the model has the dynamic update slice op."""
|
@@ -179,8 +178,8 @@ class TestModelConversion(googletest.TestCase):
|
|
179
178
|
)
|
180
179
|
|
181
180
|
@googletest.skipIf(
|
182
|
-
|
183
|
-
reason="tests with custom ops are not supported
|
181
|
+
ai_edge_torch.config.in_oss,
|
182
|
+
reason="tests with custom ops are not supported in oss",
|
184
183
|
)
|
185
184
|
def test_tiny_llama_multisig(self):
|
186
185
|
config = tiny_llama.get_fake_model_config()
|
@@ -16,7 +16,6 @@
|
|
16
16
|
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
|
-
from ai_edge_torch import config as ai_edge_config
|
20
19
|
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
|
21
20
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
22
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
@@ -91,8 +90,8 @@ class TestModelConversion(googletest.TestCase):
|
|
91
90
|
)
|
92
91
|
|
93
92
|
@googletest.skipIf(
|
94
|
-
|
95
|
-
reason="tests with custom ops are not supported
|
93
|
+
ai_edge_torch.config.in_oss,
|
94
|
+
reason="tests with custom ops are not supported in oss",
|
96
95
|
)
|
97
96
|
def test_gemma1(self):
|
98
97
|
config = gemma1.get_fake_model_config()
|
@@ -100,8 +99,8 @@ class TestModelConversion(googletest.TestCase):
|
|
100
99
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
101
100
|
|
102
101
|
@googletest.skipIf(
|
103
|
-
|
104
|
-
reason="tests with custom ops are not supported
|
102
|
+
ai_edge_torch.config.in_oss,
|
103
|
+
reason="tests with custom ops are not supported in oss",
|
105
104
|
)
|
106
105
|
def test_gemma2(self):
|
107
106
|
config = gemma2.get_fake_model_config()
|
@@ -109,8 +108,8 @@ class TestModelConversion(googletest.TestCase):
|
|
109
108
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
110
109
|
|
111
110
|
@googletest.skipIf(
|
112
|
-
|
113
|
-
reason="tests with custom ops are not supported
|
111
|
+
ai_edge_torch.config.in_oss,
|
112
|
+
reason="tests with custom ops are not supported in oss",
|
114
113
|
)
|
115
114
|
def test_llama(self):
|
116
115
|
config = llama.get_fake_model_config()
|
@@ -118,8 +117,8 @@ class TestModelConversion(googletest.TestCase):
|
|
118
117
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
119
118
|
|
120
119
|
@googletest.skipIf(
|
121
|
-
|
122
|
-
reason="tests with custom ops are not supported
|
120
|
+
ai_edge_torch.config.in_oss,
|
121
|
+
reason="tests with custom ops are not supported in oss",
|
123
122
|
)
|
124
123
|
def test_phi2(self):
|
125
124
|
config = phi2.get_fake_model_config()
|
@@ -128,8 +127,8 @@ class TestModelConversion(googletest.TestCase):
|
|
128
127
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
129
128
|
|
130
129
|
@googletest.skipIf(
|
131
|
-
|
132
|
-
reason="tests with custom ops are not supported
|
130
|
+
ai_edge_torch.config.in_oss,
|
131
|
+
reason="tests with custom ops are not supported in oss",
|
133
132
|
)
|
134
133
|
def test_phi3(self):
|
135
134
|
config = phi3.get_fake_model_config()
|
@@ -137,8 +136,8 @@ class TestModelConversion(googletest.TestCase):
|
|
137
136
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
138
137
|
|
139
138
|
@googletest.skipIf(
|
140
|
-
|
141
|
-
reason="tests with custom ops are not supported
|
139
|
+
ai_edge_torch.config.in_oss,
|
140
|
+
reason="tests with custom ops are not supported in oss",
|
142
141
|
)
|
143
142
|
def test_smollm(self):
|
144
143
|
config = smollm.get_fake_model_config()
|
@@ -146,8 +145,8 @@ class TestModelConversion(googletest.TestCase):
|
|
146
145
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
147
146
|
|
148
147
|
@googletest.skipIf(
|
149
|
-
|
150
|
-
reason="tests with custom ops are not supported
|
148
|
+
ai_edge_torch.config.in_oss,
|
149
|
+
reason="tests with custom ops are not supported in oss",
|
151
150
|
)
|
152
151
|
def test_openelm(self):
|
153
152
|
config = openelm.get_fake_model_config()
|
@@ -155,8 +154,8 @@ class TestModelConversion(googletest.TestCase):
|
|
155
154
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
156
155
|
|
157
156
|
@googletest.skipIf(
|
158
|
-
|
159
|
-
reason="tests with custom ops are not supported
|
157
|
+
ai_edge_torch.config.in_oss,
|
158
|
+
reason="tests with custom ops are not supported in oss",
|
160
159
|
)
|
161
160
|
def test_qwen(self):
|
162
161
|
config = qwen.get_fake_model_config()
|
@@ -164,8 +163,8 @@ class TestModelConversion(googletest.TestCase):
|
|
164
163
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
165
164
|
|
166
165
|
@googletest.skipIf(
|
167
|
-
|
168
|
-
reason="tests with custom ops are not supported
|
166
|
+
ai_edge_torch.config.in_oss,
|
167
|
+
reason="tests with custom ops are not supported in oss",
|
169
168
|
)
|
170
169
|
def test_amd_llama_135m(self):
|
171
170
|
config = amd_llama_135m.get_fake_model_config()
|
@@ -173,8 +172,8 @@ class TestModelConversion(googletest.TestCase):
|
|
173
172
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
174
173
|
|
175
174
|
@googletest.skipIf(
|
176
|
-
|
177
|
-
reason="tests with custom ops are not supported
|
175
|
+
ai_edge_torch.config.in_oss,
|
176
|
+
reason="tests with custom ops are not supported in oss",
|
178
177
|
)
|
179
178
|
def disabled_test_paligemma(self):
|
180
179
|
config = paligemma.get_fake_model_config()
|
@@ -222,8 +221,8 @@ class TestModelConversion(googletest.TestCase):
|
|
222
221
|
)
|
223
222
|
|
224
223
|
@googletest.skipIf(
|
225
|
-
|
226
|
-
reason="tests with custom ops are not supported
|
224
|
+
ai_edge_torch.config.in_oss,
|
225
|
+
reason="tests with custom ops are not supported in oss",
|
227
226
|
)
|
228
227
|
def test_stable_diffusion_clip(self):
|
229
228
|
config = sd_clip.get_fake_model_config()
|
@@ -254,8 +253,8 @@ class TestModelConversion(googletest.TestCase):
|
|
254
253
|
)
|
255
254
|
|
256
255
|
@googletest.skipIf(
|
257
|
-
|
258
|
-
reason="tests with custom ops are not supported
|
256
|
+
ai_edge_torch.config.in_oss,
|
257
|
+
reason="tests with custom ops are not supported in oss",
|
259
258
|
)
|
260
259
|
def test_stable_diffusion_diffusion(self):
|
261
260
|
config = sd_diffusion.get_fake_model_config(2)
|
@@ -296,8 +295,8 @@ class TestModelConversion(googletest.TestCase):
|
|
296
295
|
)
|
297
296
|
|
298
297
|
@googletest.skipIf(
|
299
|
-
|
300
|
-
reason="tests with custom ops are not supported
|
298
|
+
ai_edge_torch.config.in_oss,
|
299
|
+
reason="tests with custom ops are not supported in oss",
|
301
300
|
)
|
302
301
|
def test_stable_diffusion_decoder(self):
|
303
302
|
config = sd_decoder.get_fake_model_config()
|
@@ -15,13 +15,15 @@
|
|
15
15
|
|
16
16
|
from typing import Any, Optional
|
17
17
|
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from ai_edge_torch._convert import signature
|
20
20
|
from ai_edge_torch.quantize import quant_config as qcfg
|
21
21
|
import torch
|
22
22
|
|
23
|
+
config = _config.config
|
24
|
+
|
23
25
|
# isort: off
|
24
|
-
if config.
|
26
|
+
if config.use_torch_xla:
|
25
27
|
from ai_edge_torch.lowertools import torch_xla_utils as utils
|
26
28
|
from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
|
27
29
|
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
import re
|
17
17
|
from typing import Optional
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from absl.testing import absltest as googletest
|
20
20
|
|
21
|
+
config = _config.config
|
22
|
+
|
21
23
|
|
22
24
|
def _extract_backend_configs(mlir):
|
23
25
|
mlir = mlir.replace("\\22", '"')
|
@@ -38,7 +40,7 @@ def assert_string_count(
|
|
38
40
|
if odml_torch_attr_counter is None:
|
39
41
|
odml_torch_attr_counter = {}
|
40
42
|
|
41
|
-
if config.
|
43
|
+
if config.use_torch_xla:
|
42
44
|
for key in torch_xla_pattern_counter:
|
43
45
|
test_case.assertEqual(
|
44
46
|
mlir.count(key),
|
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
276
276
|
interior_padding if i == dim else 0 for i in range(rank)
|
277
277
|
],
|
278
278
|
)
|
279
|
-
|
280
|
-
|
279
|
+
|
280
|
+
slices = [
|
281
281
|
slice(start, end, step) if i == dim else slice(None, None, None)
|
282
282
|
for i in range(rank)
|
283
|
-
]
|
283
|
+
]
|
284
|
+
pred = np.ones(self.type.shape, dtype=np.bool_)
|
285
|
+
pred[np.index_exp[tuple(slices)]] = False
|
284
286
|
pred = stablehlo.constant(
|
285
287
|
ir.DenseElementsAttr.get(
|
286
288
|
np.packbits(pred, bitorder="little"),
|
@@ -232,7 +232,9 @@ def _aten_convolution(
|
|
232
232
|
|
233
233
|
if bias is not None:
|
234
234
|
# broadcast [C] to [NCHW]
|
235
|
-
broadcasted_bias = stablehlo.broadcast_in_dim(
|
235
|
+
broadcasted_bias = stablehlo.broadcast_in_dim(
|
236
|
+
output_type, bias, ir.DenseI64ArrayAttr.get([1])
|
237
|
+
)
|
236
238
|
res = stablehlo.add(
|
237
239
|
lhs=res,
|
238
240
|
rhs=broadcasted_bias,
|
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
|
|
20
20
|
from ai_edge_torch.odml_torch.lowerings import utils
|
21
21
|
from jax._src.lib.mlir import ir
|
22
22
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import numpy as np
|
23
24
|
import torch
|
24
25
|
|
25
26
|
|
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
|
|
66
67
|
normalized_rank = len(normalized_shape)
|
67
68
|
if weight is not None:
|
68
69
|
weight = stablehlo.broadcast_in_dim(
|
69
|
-
data_type,
|
70
|
+
data_type,
|
71
|
+
weight,
|
72
|
+
ir.DenseI64ArrayAttr.get(
|
73
|
+
list(range(data_rank - normalized_rank, data_rank))
|
74
|
+
),
|
70
75
|
)
|
71
76
|
output = stablehlo.multiply(weight, output)
|
72
77
|
if bias is not None:
|
73
78
|
bias = stablehlo.broadcast_in_dim(
|
74
|
-
data_type,
|
79
|
+
data_type,
|
80
|
+
bias,
|
81
|
+
ir.DenseI64ArrayAttr.get(
|
82
|
+
list(range(data_rank - normalized_rank, data_rank))
|
83
|
+
),
|
75
84
|
)
|
76
85
|
output = stablehlo.add(bias, output)
|
77
86
|
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
|
16
|
-
from typing import Union, cast
|
16
|
+
from typing import Optional, Union, cast
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch.lowerings import context
|
19
19
|
from ai_edge_torch.odml_torch.lowerings import utils
|
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
|
|
30
30
|
|
31
31
|
|
32
32
|
def _uniform_quantized_type(
|
33
|
-
stored_type: str
|
34
|
-
expressed_type: str
|
33
|
+
stored_type: Union[str, ir.Type],
|
34
|
+
expressed_type: Union[str, ir.Type],
|
35
35
|
*,
|
36
|
-
scale=float
|
37
|
-
zero_point=float
|
38
|
-
storage_type_min: int
|
39
|
-
storage_type_max: int
|
40
|
-
channel_axis: int
|
41
|
-
channel_axis_size: int
|
36
|
+
scale=Union[float, list[float], tuple[float]],
|
37
|
+
zero_point=Union[float, list[float], tuple[float]],
|
38
|
+
storage_type_min: Optional[int] = None,
|
39
|
+
storage_type_max: Optional[int] = None,
|
40
|
+
channel_axis: Optional[int] = None,
|
41
|
+
channel_axis_size: Optional[int] = None,
|
42
42
|
):
|
43
43
|
"""Polyfill for quant.UniformQuantizedType."""
|
44
44
|
if storage_type_min and storage_type_max:
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241215
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
|
|
11
11
|
Classifier: License :: OSI Approved :: Apache Software License
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
13
13
|
Classifier: Programming Language :: Python :: 3 :: Only
|
14
|
-
Classifier: Programming Language :: Python :: 3.9
|
15
14
|
Classifier: Programming Language :: Python :: 3.10
|
16
15
|
Classifier: Programming Language :: Python :: 3.11
|
17
16
|
Classifier: Topic :: Scientific/Engineering
|
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
19
|
Classifier: Topic :: Software Development
|
21
20
|
Classifier: Topic :: Software Development :: Libraries
|
22
21
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
23
|
-
Requires-Python: >=3.
|
22
|
+
Requires-Python: >=3.10
|
24
23
|
Description-Content-Type: text/markdown
|
25
24
|
License-File: LICENSE
|
26
25
|
Requires-Dist: numpy
|
@@ -28,10 +27,13 @@ Requires-Dist: scipy
|
|
28
27
|
Requires-Dist: safetensors
|
29
28
|
Requires-Dist: tabulate
|
30
29
|
Requires-Dist: torch>=2.4.0
|
31
|
-
Requires-Dist: torch-xla>=2.4.0
|
32
30
|
Requires-Dist: tf-nightly>=2.19.0.dev20241201
|
33
31
|
Requires-Dist: ai-edge-litert-nightly
|
34
32
|
Requires-Dist: ai-edge-quantizer-nightly
|
33
|
+
Requires-Dist: jax
|
34
|
+
Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
|
35
|
+
Provides-Extra: torch-xla
|
36
|
+
Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
|
35
37
|
|
36
38
|
Library that supports converting PyTorch models into a .tflite format, which can
|
37
39
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|
@@ -1,9 +1,9 @@
|
|
1
|
-
ai_edge_torch/__init__.py,sha256=
|
2
|
-
ai_edge_torch/
|
1
|
+
ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
|
2
|
+
ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=K6IQHV_-ygm-XHO2-Za1f4YtOckCWkp3RoVrufaooRk,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
29
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=gK9QJuLbpjXt0l6tVnzl9Miq6GLkJR-hB67i3VE13Og,17224
|
30
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
31
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -119,7 +119,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoT
|
|
119
119
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
120
120
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
121
121
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
122
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
122
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
|
123
123
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
124
124
|
ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
|
125
125
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
@@ -139,8 +139,8 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
|
|
139
139
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
|
140
140
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
|
141
141
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
142
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
143
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
142
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
143
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
|
144
144
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
145
145
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
146
146
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
@@ -159,12 +159,11 @@ ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge
|
|
159
159
|
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
|
160
160
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
161
161
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
|
162
|
-
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
|
163
162
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
164
|
-
ai_edge_torch/lowertools/_shim.py,sha256=
|
163
|
+
ai_edge_torch/lowertools/_shim.py,sha256=xJIHDSWNoF4PkkT0JkjeJxgguQ9JGEwooJf9xZNkVRU,3058
|
165
164
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
166
165
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
|
167
|
-
ai_edge_torch/lowertools/test_utils.py,sha256=
|
166
|
+
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
168
167
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
|
169
168
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
170
169
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
@@ -183,12 +182,12 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
|
|
183
182
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
184
183
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
185
184
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
|
186
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
185
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
|
187
186
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
188
|
-
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=
|
187
|
+
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
189
188
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
190
|
-
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=
|
191
|
-
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=
|
189
|
+
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
190
|
+
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=GEs83mtEjh8GOW_OATI_ur11VKujrOL2xdZeZ0l1HtM,6100
|
192
191
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
193
192
|
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
194
193
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
@@ -201,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
201
200
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
202
201
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
203
202
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
204
|
-
ai_edge_torch_nightly-0.3.0.
|
205
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/METADATA,sha256=kkijdPdACWUh6ocM7K99XNhICV9dA4uH3KlQZ-R2NFg,1966
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241215.dist-info/RECORD,,
|
ai_edge_torch/config.py
DELETED
@@ -1,27 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""Provides a configuration for the AI Edge Torch library."""
|
17
|
-
|
18
|
-
import dataclasses
|
19
|
-
import os
|
20
|
-
|
21
|
-
|
22
|
-
@dataclasses.dataclass
|
23
|
-
class Config:
|
24
|
-
use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
|
25
|
-
"1",
|
26
|
-
"true",
|
27
|
-
)
|
@@ -1,283 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Tests for StableHLOCompositeBuilder."""
|
16
|
-
|
17
|
-
import math
|
18
|
-
|
19
|
-
from ai_edge_torch import config
|
20
|
-
from ai_edge_torch import lowertools
|
21
|
-
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
22
|
-
import torch
|
23
|
-
import torch.nn.functional as F
|
24
|
-
|
25
|
-
from absl.testing import absltest as googletest
|
26
|
-
|
27
|
-
|
28
|
-
def _export_stablehlo_mlir(model, args):
|
29
|
-
ep = torch.export.export(model, args)
|
30
|
-
return lowertools.exported_program_to_mlir_text(ep)
|
31
|
-
|
32
|
-
|
33
|
-
@googletest.skipIf(
|
34
|
-
not config.Config.use_torch_xla,
|
35
|
-
reason="The odml_torch counter part is in odml_torch.",
|
36
|
-
)
|
37
|
-
class TestStableHLOCompositeBuilder(googletest.TestCase):
|
38
|
-
|
39
|
-
def test_build_composite(self):
|
40
|
-
class SampleModel(torch.nn.Module):
|
41
|
-
|
42
|
-
def forward(self, x):
|
43
|
-
builder = StableHLOCompositeBuilder(name="test.plus_two")
|
44
|
-
y = x + 1
|
45
|
-
y = builder.mark_inputs(y)
|
46
|
-
z = y + 2
|
47
|
-
z = builder.mark_outputs(z)
|
48
|
-
return z
|
49
|
-
|
50
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
51
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1)
|
52
|
-
|
53
|
-
def test_build_multiple_composites(self):
|
54
|
-
class SampleModel(torch.nn.Module):
|
55
|
-
|
56
|
-
def plus_one(self, x: torch.Tensor):
|
57
|
-
builder = StableHLOCompositeBuilder("test.plus_one")
|
58
|
-
x = builder.mark_inputs(x)
|
59
|
-
y = x + 1
|
60
|
-
y = builder.mark_outputs(y)
|
61
|
-
return y
|
62
|
-
|
63
|
-
def plus_two(self, x: torch.Tensor):
|
64
|
-
builder = StableHLOCompositeBuilder("test.plus_two")
|
65
|
-
x = builder.mark_inputs(x)
|
66
|
-
y = x + 2
|
67
|
-
y = builder.mark_outputs(y)
|
68
|
-
return y
|
69
|
-
|
70
|
-
def forward(self, x):
|
71
|
-
x = self.plus_two(x)
|
72
|
-
x = x + 3
|
73
|
-
x = self.plus_one(x)
|
74
|
-
x = x + 4
|
75
|
-
x = self.plus_two(x)
|
76
|
-
return x
|
77
|
-
|
78
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
79
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1)
|
80
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2)
|
81
|
-
|
82
|
-
def test_build_composite_with_attr(self):
|
83
|
-
class SampleModel(torch.nn.Module):
|
84
|
-
|
85
|
-
def __init__(self):
|
86
|
-
super().__init__()
|
87
|
-
|
88
|
-
def log_softmax(self, x: torch.Tensor, dim: int):
|
89
|
-
builder = StableHLOCompositeBuilder(
|
90
|
-
name="test.log_softmax", attr={"dim": dim}
|
91
|
-
)
|
92
|
-
x = builder.mark_inputs(x)
|
93
|
-
y = torch.nn.functional.log_softmax(x, dim=dim)
|
94
|
-
y = builder.mark_outputs(y)
|
95
|
-
return y
|
96
|
-
|
97
|
-
def forward(self, x):
|
98
|
-
x = x + 1
|
99
|
-
x = self.log_softmax(x, 0)
|
100
|
-
x = self.log_softmax(x, 1)
|
101
|
-
return x
|
102
|
-
|
103
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
104
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2)
|
105
|
-
self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1)
|
106
|
-
self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1)
|
107
|
-
|
108
|
-
def test_build_composite_with_mix_type_attrs(self):
|
109
|
-
class SampleModel(torch.nn.Module):
|
110
|
-
|
111
|
-
def __init__(self):
|
112
|
-
super().__init__()
|
113
|
-
|
114
|
-
def log_softmax(self, x: torch.Tensor, dim: int):
|
115
|
-
builder = StableHLOCompositeBuilder(
|
116
|
-
name="test.log_softmax",
|
117
|
-
attr={
|
118
|
-
"dim": dim,
|
119
|
-
"source": "torch.nn",
|
120
|
-
"version": 1.0,
|
121
|
-
},
|
122
|
-
)
|
123
|
-
x = builder.mark_inputs(x)
|
124
|
-
y = torch.nn.functional.log_softmax(x, dim=dim)
|
125
|
-
y = builder.mark_outputs(y)
|
126
|
-
return y
|
127
|
-
|
128
|
-
def forward(self, x):
|
129
|
-
x = x + 1
|
130
|
-
x = self.log_softmax(x, 0)
|
131
|
-
return x
|
132
|
-
|
133
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
134
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
|
135
|
-
self.assertEqual(
|
136
|
-
mlir.count(
|
137
|
-
'composite_attributes = {dim = 0 : i64, source = "torch.nn",'
|
138
|
-
" version = 1.000000e+00 : f32}"
|
139
|
-
),
|
140
|
-
1,
|
141
|
-
)
|
142
|
-
|
143
|
-
def test_sdpa_composite(self):
|
144
|
-
class SDPAModel(torch.nn.Module):
|
145
|
-
|
146
|
-
def scaled_dot_product_attention(
|
147
|
-
self,
|
148
|
-
q: torch.Tensor,
|
149
|
-
k: torch.Tensor,
|
150
|
-
v: torch.Tensor,
|
151
|
-
head_size: int,
|
152
|
-
mask: torch.Tensor,
|
153
|
-
):
|
154
|
-
builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention")
|
155
|
-
q, k, v, mask = builder.mark_inputs(q, k, v, mask)
|
156
|
-
|
157
|
-
scale = 1.0 / math.sqrt(head_size)
|
158
|
-
|
159
|
-
q = q.transpose(1, 2)
|
160
|
-
k = k.transpose(1, 2)
|
161
|
-
v = v.transpose(1, 2)
|
162
|
-
y = F.scaled_dot_product_attention(
|
163
|
-
q,
|
164
|
-
k,
|
165
|
-
v,
|
166
|
-
attn_mask=mask,
|
167
|
-
dropout_p=0.0,
|
168
|
-
is_causal=mask is None,
|
169
|
-
scale=scale,
|
170
|
-
)
|
171
|
-
result = y.transpose(1, 2)
|
172
|
-
result = builder.mark_outputs(result)
|
173
|
-
return result
|
174
|
-
|
175
|
-
def forward(self, q, k, v, mask):
|
176
|
-
x = self.scaled_dot_product_attention(
|
177
|
-
q,
|
178
|
-
k,
|
179
|
-
v,
|
180
|
-
8,
|
181
|
-
mask,
|
182
|
-
)
|
183
|
-
return x
|
184
|
-
|
185
|
-
query = torch.rand(1, 1, 32, 4)
|
186
|
-
key = torch.rand(1, 500, 1, 4)
|
187
|
-
value = torch.rand(1, 500, 1, 4)
|
188
|
-
mask = torch.rand(1, 1, 1, 500)
|
189
|
-
|
190
|
-
mlir = _export_stablehlo_mlir(
|
191
|
-
SDPAModel().eval(),
|
192
|
-
(query, key, value, mask),
|
193
|
-
)
|
194
|
-
self.assertEqual(
|
195
|
-
mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1
|
196
|
-
)
|
197
|
-
|
198
|
-
def test_sdpa_composite_with_attr(self):
|
199
|
-
class SDPAModel(torch.nn.Module):
|
200
|
-
|
201
|
-
def scaled_dot_product_attention(
|
202
|
-
self,
|
203
|
-
q: torch.Tensor,
|
204
|
-
k: torch.Tensor,
|
205
|
-
v: torch.Tensor,
|
206
|
-
head_size: int,
|
207
|
-
include_captanh: bool,
|
208
|
-
):
|
209
|
-
builder = StableHLOCompositeBuilder(
|
210
|
-
name="test.scaled_dot_product_attention",
|
211
|
-
attr={"include_captanh": include_captanh},
|
212
|
-
)
|
213
|
-
q, k, v = builder.mark_inputs(q, k, v)
|
214
|
-
|
215
|
-
scale = 1.0 / math.sqrt(head_size)
|
216
|
-
|
217
|
-
q = q.transpose(1, 2)
|
218
|
-
k = k.transpose(1, 2)
|
219
|
-
v = v.transpose(1, 2)
|
220
|
-
y = F.scaled_dot_product_attention(
|
221
|
-
q,
|
222
|
-
k,
|
223
|
-
v,
|
224
|
-
attn_mask=None,
|
225
|
-
dropout_p=0.0,
|
226
|
-
is_causal=True,
|
227
|
-
scale=scale,
|
228
|
-
)
|
229
|
-
result = y.transpose(1, 2)
|
230
|
-
result = builder.mark_outputs(result)
|
231
|
-
return result
|
232
|
-
|
233
|
-
def forward(self, q, k, v):
|
234
|
-
x = self.scaled_dot_product_attention(q, k, v, 8, True)
|
235
|
-
y = self.scaled_dot_product_attention(q, k, v, 8, False)
|
236
|
-
return x + y
|
237
|
-
|
238
|
-
query = torch.rand(1, 1, 32, 4)
|
239
|
-
key = torch.rand(1, 500, 1, 4)
|
240
|
-
value = torch.rand(1, 500, 1, 4)
|
241
|
-
mlir = _export_stablehlo_mlir(
|
242
|
-
SDPAModel().eval(),
|
243
|
-
(query, key, value),
|
244
|
-
)
|
245
|
-
self.assertEqual(
|
246
|
-
mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
|
247
|
-
)
|
248
|
-
self.assertEqual(
|
249
|
-
mlir.count("composite_attributes = {include_captanh = true}"), 1
|
250
|
-
)
|
251
|
-
self.assertEqual(
|
252
|
-
mlir.count("composite_attributes = {include_captanh = false}"), 1
|
253
|
-
)
|
254
|
-
|
255
|
-
def test_build_composite_with_multiple_inputs_outputs(self):
|
256
|
-
class SampleModel(torch.nn.Module):
|
257
|
-
|
258
|
-
def mimo_sample(self, a, b, c):
|
259
|
-
builder = StableHLOCompositeBuilder(name="test.mimo_sample")
|
260
|
-
|
261
|
-
a, b, c = builder.mark_inputs(a, b, c)
|
262
|
-
x = a + b + c
|
263
|
-
y = (a - b) * x
|
264
|
-
z = (c + 1.0) * a
|
265
|
-
x, y, z = builder.mark_outputs(x, y, z)
|
266
|
-
|
267
|
-
result = x + y * z
|
268
|
-
return result
|
269
|
-
|
270
|
-
def forward(self, a, b, c):
|
271
|
-
x = self.mimo_sample(a, b, c)
|
272
|
-
x = self.mimo_sample(a, b, x)
|
273
|
-
x = self.mimo_sample(x, x, c)
|
274
|
-
return x
|
275
|
-
|
276
|
-
mlir = _export_stablehlo_mlir(
|
277
|
-
SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2))
|
278
|
-
)
|
279
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3)
|
280
|
-
|
281
|
-
|
282
|
-
if __name__ == "__main__":
|
283
|
-
googletest.main()
|
File without changes
|
File without changes
|