ai-edge-torch-nightly 0.3.0.dev20241213__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/generative/layers/kv_cache.py +1 -2
- ai_edge_torch/generative/test/test_model_conversion.py +8 -9
- ai_edge_torch/generative/test/test_model_conversion_large.py +26 -27
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +6 -4
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +18 -19
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -13,13 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from ai_edge_torch._config import config
|
16
17
|
from ai_edge_torch._convert.converter import convert
|
17
18
|
from ai_edge_torch._convert.converter import signature
|
18
19
|
from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
|
19
20
|
from ai_edge_torch.model import Model
|
20
21
|
from ai_edge_torch.version import __version__
|
21
22
|
|
22
|
-
|
23
23
|
def load(path: str) -> Model:
|
24
24
|
"""Imports an ai_edge_torch model from disk.
|
25
25
|
|
ai_edge_torch/_config.py
ADDED
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Provides a configuration for the ai-edge-torch."""
|
17
|
+
|
18
|
+
import functools
|
19
|
+
import logging
|
20
|
+
import os
|
21
|
+
|
22
|
+
__all__ = ["config"]
|
23
|
+
|
24
|
+
|
25
|
+
class _Config:
|
26
|
+
"""ai-edge-torch global configs."""
|
27
|
+
|
28
|
+
@property
|
29
|
+
@functools.cache # pylint: disable=method-cache-max-size-none
|
30
|
+
def use_torch_xla(self) -> bool:
|
31
|
+
"""True if using torch_xla to lower torch ops to StableHLO.
|
32
|
+
|
33
|
+
To use torch_xla as the lowering backend, set environment variable
|
34
|
+
`USE_TORCH_XLA` to "true".
|
35
|
+
"""
|
36
|
+
var = os.environ.get("USE_TORCH_XLA", "false")
|
37
|
+
var = var.lower().strip()
|
38
|
+
if var in ("y", "yes", "t", "true", "on", "1"):
|
39
|
+
return True
|
40
|
+
elif var in ("n", "no", "f", "false", "off", "0"):
|
41
|
+
return False
|
42
|
+
else:
|
43
|
+
logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
|
44
|
+
return False
|
45
|
+
|
46
|
+
@property
|
47
|
+
def in_oss(self) -> bool:
|
48
|
+
"""True if the code is not running in google internal environment."""
|
49
|
+
return True
|
50
|
+
|
51
|
+
|
52
|
+
config = _Config()
|
@@ -19,7 +19,6 @@ import os
|
|
19
19
|
from typing import Tuple
|
20
20
|
|
21
21
|
import ai_edge_torch
|
22
|
-
from ai_edge_torch import config
|
23
22
|
from ai_edge_torch._convert import conversion_utils
|
24
23
|
from ai_edge_torch.quantize import pt2e_quantizer
|
25
24
|
from ai_edge_torch.testing import model_coverage
|
@@ -292,7 +291,7 @@ class TestConvert(googletest.TestCase):
|
|
292
291
|
self.assertTrue(result)
|
293
292
|
|
294
293
|
@googletest.skipIf(
|
295
|
-
not config.
|
294
|
+
not ai_edge_torch.config.use_torch_xla,
|
296
295
|
reason="Shape polymorphism is not yet support with odml_torch.",
|
297
296
|
)
|
298
297
|
def test_convert_model_with_dynamic_batch(self):
|
@@ -20,6 +20,7 @@ from typing import List, Tuple
|
|
20
20
|
|
21
21
|
from ai_edge_torch import hlfb
|
22
22
|
from ai_edge_torch.generative.layers import model_config
|
23
|
+
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
|
23
24
|
import torch
|
24
25
|
import torch.utils._pytree as pytree
|
25
26
|
|
@@ -159,8 +160,6 @@ def update(
|
|
159
160
|
Returns:
|
160
161
|
KVCacheEntry: The updated KVCache entry based on the passed inputs.
|
161
162
|
"""
|
162
|
-
# Turn dynamic_update_slice updates off for now.
|
163
|
-
use_dus=False
|
164
163
|
update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
|
165
164
|
return update_kv_cache(cache, input_pos, k_slice, v_slice)
|
166
165
|
|
@@ -16,7 +16,6 @@
|
|
16
16
|
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
|
-
from ai_edge_torch import config as ai_edge_config
|
20
19
|
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
|
21
20
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
22
21
|
from ai_edge_torch.generative.layers import kv_cache
|
@@ -83,22 +82,22 @@ class TestModelConversion(googletest.TestCase):
|
|
83
82
|
)
|
84
83
|
|
85
84
|
@googletest.skipIf(
|
86
|
-
|
87
|
-
reason="tests with custom ops are not supported
|
85
|
+
ai_edge_torch.config.in_oss,
|
86
|
+
reason="tests with custom ops are not supported in oss",
|
88
87
|
)
|
89
88
|
def test_toy_model_with_kv_cache(self):
|
90
89
|
self._test_model_with_kv_cache(enable_hlfb=False)
|
91
90
|
|
92
91
|
@googletest.skipIf(
|
93
|
-
|
94
|
-
reason="tests with custom ops are not supported
|
92
|
+
ai_edge_torch.config.in_oss,
|
93
|
+
reason="tests with custom ops are not supported in oss",
|
95
94
|
)
|
96
95
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
97
96
|
self._test_model_with_kv_cache(enable_hlfb=True)
|
98
97
|
|
99
98
|
@googletest.skipIf(
|
100
|
-
|
101
|
-
reason="tests with custom ops are not supported
|
99
|
+
ai_edge_torch.config.in_oss,
|
100
|
+
reason="tests with custom ops are not supported in oss",
|
102
101
|
)
|
103
102
|
def test_toy_model_has_dus_op(self):
|
104
103
|
"""Tests that the model has the dynamic update slice op."""
|
@@ -179,8 +178,8 @@ class TestModelConversion(googletest.TestCase):
|
|
179
178
|
)
|
180
179
|
|
181
180
|
@googletest.skipIf(
|
182
|
-
|
183
|
-
reason="tests with custom ops are not supported
|
181
|
+
ai_edge_torch.config.in_oss,
|
182
|
+
reason="tests with custom ops are not supported in oss",
|
184
183
|
)
|
185
184
|
def test_tiny_llama_multisig(self):
|
186
185
|
config = tiny_llama.get_fake_model_config()
|
@@ -16,7 +16,6 @@
|
|
16
16
|
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
|
-
from ai_edge_torch import config as ai_edge_config
|
20
19
|
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
|
21
20
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
22
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
@@ -91,8 +90,8 @@ class TestModelConversion(googletest.TestCase):
|
|
91
90
|
)
|
92
91
|
|
93
92
|
@googletest.skipIf(
|
94
|
-
|
95
|
-
reason="tests with custom ops are not supported
|
93
|
+
ai_edge_torch.config.in_oss,
|
94
|
+
reason="tests with custom ops are not supported in oss",
|
96
95
|
)
|
97
96
|
def test_gemma1(self):
|
98
97
|
config = gemma1.get_fake_model_config()
|
@@ -100,8 +99,8 @@ class TestModelConversion(googletest.TestCase):
|
|
100
99
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
101
100
|
|
102
101
|
@googletest.skipIf(
|
103
|
-
|
104
|
-
reason="tests with custom ops are not supported
|
102
|
+
ai_edge_torch.config.in_oss,
|
103
|
+
reason="tests with custom ops are not supported in oss",
|
105
104
|
)
|
106
105
|
def test_gemma2(self):
|
107
106
|
config = gemma2.get_fake_model_config()
|
@@ -109,8 +108,8 @@ class TestModelConversion(googletest.TestCase):
|
|
109
108
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
110
109
|
|
111
110
|
@googletest.skipIf(
|
112
|
-
|
113
|
-
reason="tests with custom ops are not supported
|
111
|
+
ai_edge_torch.config.in_oss,
|
112
|
+
reason="tests with custom ops are not supported in oss",
|
114
113
|
)
|
115
114
|
def test_llama(self):
|
116
115
|
config = llama.get_fake_model_config()
|
@@ -118,8 +117,8 @@ class TestModelConversion(googletest.TestCase):
|
|
118
117
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
119
118
|
|
120
119
|
@googletest.skipIf(
|
121
|
-
|
122
|
-
reason="tests with custom ops are not supported
|
120
|
+
ai_edge_torch.config.in_oss,
|
121
|
+
reason="tests with custom ops are not supported in oss",
|
123
122
|
)
|
124
123
|
def test_phi2(self):
|
125
124
|
config = phi2.get_fake_model_config()
|
@@ -128,8 +127,8 @@ class TestModelConversion(googletest.TestCase):
|
|
128
127
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
129
128
|
|
130
129
|
@googletest.skipIf(
|
131
|
-
|
132
|
-
reason="tests with custom ops are not supported
|
130
|
+
ai_edge_torch.config.in_oss,
|
131
|
+
reason="tests with custom ops are not supported in oss",
|
133
132
|
)
|
134
133
|
def test_phi3(self):
|
135
134
|
config = phi3.get_fake_model_config()
|
@@ -137,8 +136,8 @@ class TestModelConversion(googletest.TestCase):
|
|
137
136
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
138
137
|
|
139
138
|
@googletest.skipIf(
|
140
|
-
|
141
|
-
reason="tests with custom ops are not supported
|
139
|
+
ai_edge_torch.config.in_oss,
|
140
|
+
reason="tests with custom ops are not supported in oss",
|
142
141
|
)
|
143
142
|
def test_smollm(self):
|
144
143
|
config = smollm.get_fake_model_config()
|
@@ -146,8 +145,8 @@ class TestModelConversion(googletest.TestCase):
|
|
146
145
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
147
146
|
|
148
147
|
@googletest.skipIf(
|
149
|
-
|
150
|
-
reason="tests with custom ops are not supported
|
148
|
+
ai_edge_torch.config.in_oss,
|
149
|
+
reason="tests with custom ops are not supported in oss",
|
151
150
|
)
|
152
151
|
def test_openelm(self):
|
153
152
|
config = openelm.get_fake_model_config()
|
@@ -155,8 +154,8 @@ class TestModelConversion(googletest.TestCase):
|
|
155
154
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
156
155
|
|
157
156
|
@googletest.skipIf(
|
158
|
-
|
159
|
-
reason="tests with custom ops are not supported
|
157
|
+
ai_edge_torch.config.in_oss,
|
158
|
+
reason="tests with custom ops are not supported in oss",
|
160
159
|
)
|
161
160
|
def test_qwen(self):
|
162
161
|
config = qwen.get_fake_model_config()
|
@@ -164,8 +163,8 @@ class TestModelConversion(googletest.TestCase):
|
|
164
163
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
165
164
|
|
166
165
|
@googletest.skipIf(
|
167
|
-
|
168
|
-
reason="tests with custom ops are not supported
|
166
|
+
ai_edge_torch.config.in_oss,
|
167
|
+
reason="tests with custom ops are not supported in oss",
|
169
168
|
)
|
170
169
|
def test_amd_llama_135m(self):
|
171
170
|
config = amd_llama_135m.get_fake_model_config()
|
@@ -173,8 +172,8 @@ class TestModelConversion(googletest.TestCase):
|
|
173
172
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
174
173
|
|
175
174
|
@googletest.skipIf(
|
176
|
-
|
177
|
-
reason="tests with custom ops are not supported
|
175
|
+
ai_edge_torch.config.in_oss,
|
176
|
+
reason="tests with custom ops are not supported in oss",
|
178
177
|
)
|
179
178
|
def disabled_test_paligemma(self):
|
180
179
|
config = paligemma.get_fake_model_config()
|
@@ -222,8 +221,8 @@ class TestModelConversion(googletest.TestCase):
|
|
222
221
|
)
|
223
222
|
|
224
223
|
@googletest.skipIf(
|
225
|
-
|
226
|
-
reason="tests with custom ops are not supported
|
224
|
+
ai_edge_torch.config.in_oss,
|
225
|
+
reason="tests with custom ops are not supported in oss",
|
227
226
|
)
|
228
227
|
def test_stable_diffusion_clip(self):
|
229
228
|
config = sd_clip.get_fake_model_config()
|
@@ -254,8 +253,8 @@ class TestModelConversion(googletest.TestCase):
|
|
254
253
|
)
|
255
254
|
|
256
255
|
@googletest.skipIf(
|
257
|
-
|
258
|
-
reason="tests with custom ops are not supported
|
256
|
+
ai_edge_torch.config.in_oss,
|
257
|
+
reason="tests with custom ops are not supported in oss",
|
259
258
|
)
|
260
259
|
def test_stable_diffusion_diffusion(self):
|
261
260
|
config = sd_diffusion.get_fake_model_config(2)
|
@@ -296,8 +295,8 @@ class TestModelConversion(googletest.TestCase):
|
|
296
295
|
)
|
297
296
|
|
298
297
|
@googletest.skipIf(
|
299
|
-
|
300
|
-
reason="tests with custom ops are not supported
|
298
|
+
ai_edge_torch.config.in_oss,
|
299
|
+
reason="tests with custom ops are not supported in oss",
|
301
300
|
)
|
302
301
|
def test_stable_diffusion_decoder(self):
|
303
302
|
config = sd_decoder.get_fake_model_config()
|
@@ -15,13 +15,15 @@
|
|
15
15
|
|
16
16
|
from typing import Any, Optional
|
17
17
|
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from ai_edge_torch._convert import signature
|
20
20
|
from ai_edge_torch.quantize import quant_config as qcfg
|
21
21
|
import torch
|
22
22
|
|
23
|
+
config = _config.config
|
24
|
+
|
23
25
|
# isort: off
|
24
|
-
if config.
|
26
|
+
if config.use_torch_xla:
|
25
27
|
from ai_edge_torch.lowertools import torch_xla_utils as utils
|
26
28
|
from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
|
27
29
|
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
import re
|
17
17
|
from typing import Optional
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from absl.testing import absltest as googletest
|
20
20
|
|
21
|
+
config = _config.config
|
22
|
+
|
21
23
|
|
22
24
|
def _extract_backend_configs(mlir):
|
23
25
|
mlir = mlir.replace("\\22", '"')
|
@@ -38,7 +40,7 @@ def assert_string_count(
|
|
38
40
|
if odml_torch_attr_counter is None:
|
39
41
|
odml_torch_attr_counter = {}
|
40
42
|
|
41
|
-
if config.
|
43
|
+
if config.use_torch_xla:
|
42
44
|
for key in torch_xla_pattern_counter:
|
43
45
|
test_case.assertEqual(
|
44
46
|
mlir.count(key),
|
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
276
276
|
interior_padding if i == dim else 0 for i in range(rank)
|
277
277
|
],
|
278
278
|
)
|
279
|
-
|
280
|
-
|
279
|
+
|
280
|
+
slices = [
|
281
281
|
slice(start, end, step) if i == dim else slice(None, None, None)
|
282
282
|
for i in range(rank)
|
283
|
-
]
|
283
|
+
]
|
284
|
+
pred = np.ones(self.type.shape, dtype=np.bool_)
|
285
|
+
pred[np.index_exp[tuple(slices)]] = False
|
284
286
|
pred = stablehlo.constant(
|
285
287
|
ir.DenseElementsAttr.get(
|
286
288
|
np.packbits(pred, bitorder="little"),
|
@@ -232,7 +232,9 @@ def _aten_convolution(
|
|
232
232
|
|
233
233
|
if bias is not None:
|
234
234
|
# broadcast [C] to [NCHW]
|
235
|
-
broadcasted_bias = stablehlo.broadcast_in_dim(
|
235
|
+
broadcasted_bias = stablehlo.broadcast_in_dim(
|
236
|
+
output_type, bias, ir.DenseI64ArrayAttr.get([1])
|
237
|
+
)
|
236
238
|
res = stablehlo.add(
|
237
239
|
lhs=res,
|
238
240
|
rhs=broadcasted_bias,
|
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
|
|
20
20
|
from ai_edge_torch.odml_torch.lowerings import utils
|
21
21
|
from jax._src.lib.mlir import ir
|
22
22
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import numpy as np
|
23
24
|
import torch
|
24
25
|
|
25
26
|
|
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
|
|
66
67
|
normalized_rank = len(normalized_shape)
|
67
68
|
if weight is not None:
|
68
69
|
weight = stablehlo.broadcast_in_dim(
|
69
|
-
data_type,
|
70
|
+
data_type,
|
71
|
+
weight,
|
72
|
+
ir.DenseI64ArrayAttr.get(
|
73
|
+
list(range(data_rank - normalized_rank, data_rank))
|
74
|
+
),
|
70
75
|
)
|
71
76
|
output = stablehlo.multiply(weight, output)
|
72
77
|
if bias is not None:
|
73
78
|
bias = stablehlo.broadcast_in_dim(
|
74
|
-
data_type,
|
79
|
+
data_type,
|
80
|
+
bias,
|
81
|
+
ir.DenseI64ArrayAttr.get(
|
82
|
+
list(range(data_rank - normalized_rank, data_rank))
|
83
|
+
),
|
75
84
|
)
|
76
85
|
output = stablehlo.add(bias, output)
|
77
86
|
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
|
16
|
-
from typing import Union, cast
|
16
|
+
from typing import Optional, Union, cast
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch.lowerings import context
|
19
19
|
from ai_edge_torch.odml_torch.lowerings import utils
|
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
|
|
30
30
|
|
31
31
|
|
32
32
|
def _uniform_quantized_type(
|
33
|
-
stored_type: str
|
34
|
-
expressed_type: str
|
33
|
+
stored_type: Union[str, ir.Type],
|
34
|
+
expressed_type: Union[str, ir.Type],
|
35
35
|
*,
|
36
|
-
scale=float
|
37
|
-
zero_point=float
|
38
|
-
storage_type_min: int
|
39
|
-
storage_type_max: int
|
40
|
-
channel_axis: int
|
41
|
-
channel_axis_size: int
|
36
|
+
scale=Union[float, list[float], tuple[float]],
|
37
|
+
zero_point=Union[float, list[float], tuple[float]],
|
38
|
+
storage_type_min: Optional[int] = None,
|
39
|
+
storage_type_max: Optional[int] = None,
|
40
|
+
channel_axis: Optional[int] = None,
|
41
|
+
channel_axis_size: Optional[int] = None,
|
42
42
|
):
|
43
43
|
"""Polyfill for quant.UniformQuantizedType."""
|
44
44
|
if storage_type_min and storage_type_max:
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241214
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
|
|
11
11
|
Classifier: License :: OSI Approved :: Apache Software License
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
13
13
|
Classifier: Programming Language :: Python :: 3 :: Only
|
14
|
-
Classifier: Programming Language :: Python :: 3.9
|
15
14
|
Classifier: Programming Language :: Python :: 3.10
|
16
15
|
Classifier: Programming Language :: Python :: 3.11
|
17
16
|
Classifier: Topic :: Scientific/Engineering
|
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
19
|
Classifier: Topic :: Software Development
|
21
20
|
Classifier: Topic :: Software Development :: Libraries
|
22
21
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
23
|
-
Requires-Python: >=3.
|
22
|
+
Requires-Python: >=3.10
|
24
23
|
Description-Content-Type: text/markdown
|
25
24
|
License-File: LICENSE
|
26
25
|
Requires-Dist: numpy
|
@@ -28,10 +27,13 @@ Requires-Dist: scipy
|
|
28
27
|
Requires-Dist: safetensors
|
29
28
|
Requires-Dist: tabulate
|
30
29
|
Requires-Dist: torch>=2.4.0
|
31
|
-
Requires-Dist: torch-xla>=2.4.0
|
32
30
|
Requires-Dist: tf-nightly>=2.19.0.dev20241201
|
33
31
|
Requires-Dist: ai-edge-litert-nightly
|
34
32
|
Requires-Dist: ai-edge-quantizer-nightly
|
33
|
+
Requires-Dist: jax
|
34
|
+
Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
|
35
|
+
Provides-Extra: torch-xla
|
36
|
+
Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
|
35
37
|
|
36
38
|
Library that supports converting PyTorch models into a .tflite format, which can
|
37
39
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|
@@ -1,9 +1,9 @@
|
|
1
|
-
ai_edge_torch/__init__.py,sha256=
|
2
|
-
ai_edge_torch/
|
1
|
+
ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
|
2
|
+
ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=iCH8lnlOrtbGwvxnT3knpY_keeu2UnrJ_ZXNK2LSvf4,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
29
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=gK9QJuLbpjXt0l6tVnzl9Miq6GLkJR-hB67i3VE13Og,17224
|
30
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
31
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -119,7 +119,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoT
|
|
119
119
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
120
120
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
121
121
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
122
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
122
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
|
123
123
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
124
124
|
ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
|
125
125
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
@@ -139,8 +139,8 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
|
|
139
139
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
|
140
140
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
|
141
141
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
142
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
143
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
142
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
143
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
|
144
144
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
145
145
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
146
146
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
@@ -159,12 +159,11 @@ ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge
|
|
159
159
|
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
|
160
160
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
161
161
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
|
162
|
-
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
|
163
162
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
164
|
-
ai_edge_torch/lowertools/_shim.py,sha256=
|
163
|
+
ai_edge_torch/lowertools/_shim.py,sha256=xJIHDSWNoF4PkkT0JkjeJxgguQ9JGEwooJf9xZNkVRU,3058
|
165
164
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
166
165
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
|
167
|
-
ai_edge_torch/lowertools/test_utils.py,sha256=
|
166
|
+
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
168
167
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
|
169
168
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
170
169
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
@@ -183,12 +182,12 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
|
|
183
182
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
184
183
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
185
184
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
|
186
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
185
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
|
187
186
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
188
|
-
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=
|
187
|
+
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
189
188
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
190
|
-
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=
|
191
|
-
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=
|
189
|
+
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
190
|
+
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=GEs83mtEjh8GOW_OATI_ur11VKujrOL2xdZeZ0l1HtM,6100
|
192
191
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
193
192
|
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
194
193
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
@@ -201,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
201
200
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
202
201
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
203
202
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
204
|
-
ai_edge_torch_nightly-0.3.0.
|
205
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/METADATA,sha256=fUbq26zB0WUU1l6eUud8vq3Nm3KSIhox74pzFSFTmoM,1966
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/RECORD,,
|
ai_edge_torch/config.py
DELETED
@@ -1,27 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""Provides a configuration for the AI Edge Torch library."""
|
17
|
-
|
18
|
-
import dataclasses
|
19
|
-
import os
|
20
|
-
|
21
|
-
|
22
|
-
@dataclasses.dataclass
|
23
|
-
class Config:
|
24
|
-
use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
|
25
|
-
"1",
|
26
|
-
"true",
|
27
|
-
)
|
@@ -1,283 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Tests for StableHLOCompositeBuilder."""
|
16
|
-
|
17
|
-
import math
|
18
|
-
|
19
|
-
from ai_edge_torch import config
|
20
|
-
from ai_edge_torch import lowertools
|
21
|
-
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
22
|
-
import torch
|
23
|
-
import torch.nn.functional as F
|
24
|
-
|
25
|
-
from absl.testing import absltest as googletest
|
26
|
-
|
27
|
-
|
28
|
-
def _export_stablehlo_mlir(model, args):
|
29
|
-
ep = torch.export.export(model, args)
|
30
|
-
return lowertools.exported_program_to_mlir_text(ep)
|
31
|
-
|
32
|
-
|
33
|
-
@googletest.skipIf(
|
34
|
-
not config.Config.use_torch_xla,
|
35
|
-
reason="The odml_torch counter part is in odml_torch.",
|
36
|
-
)
|
37
|
-
class TestStableHLOCompositeBuilder(googletest.TestCase):
|
38
|
-
|
39
|
-
def test_build_composite(self):
|
40
|
-
class SampleModel(torch.nn.Module):
|
41
|
-
|
42
|
-
def forward(self, x):
|
43
|
-
builder = StableHLOCompositeBuilder(name="test.plus_two")
|
44
|
-
y = x + 1
|
45
|
-
y = builder.mark_inputs(y)
|
46
|
-
z = y + 2
|
47
|
-
z = builder.mark_outputs(z)
|
48
|
-
return z
|
49
|
-
|
50
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
51
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1)
|
52
|
-
|
53
|
-
def test_build_multiple_composites(self):
|
54
|
-
class SampleModel(torch.nn.Module):
|
55
|
-
|
56
|
-
def plus_one(self, x: torch.Tensor):
|
57
|
-
builder = StableHLOCompositeBuilder("test.plus_one")
|
58
|
-
x = builder.mark_inputs(x)
|
59
|
-
y = x + 1
|
60
|
-
y = builder.mark_outputs(y)
|
61
|
-
return y
|
62
|
-
|
63
|
-
def plus_two(self, x: torch.Tensor):
|
64
|
-
builder = StableHLOCompositeBuilder("test.plus_two")
|
65
|
-
x = builder.mark_inputs(x)
|
66
|
-
y = x + 2
|
67
|
-
y = builder.mark_outputs(y)
|
68
|
-
return y
|
69
|
-
|
70
|
-
def forward(self, x):
|
71
|
-
x = self.plus_two(x)
|
72
|
-
x = x + 3
|
73
|
-
x = self.plus_one(x)
|
74
|
-
x = x + 4
|
75
|
-
x = self.plus_two(x)
|
76
|
-
return x
|
77
|
-
|
78
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
79
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1)
|
80
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2)
|
81
|
-
|
82
|
-
def test_build_composite_with_attr(self):
|
83
|
-
class SampleModel(torch.nn.Module):
|
84
|
-
|
85
|
-
def __init__(self):
|
86
|
-
super().__init__()
|
87
|
-
|
88
|
-
def log_softmax(self, x: torch.Tensor, dim: int):
|
89
|
-
builder = StableHLOCompositeBuilder(
|
90
|
-
name="test.log_softmax", attr={"dim": dim}
|
91
|
-
)
|
92
|
-
x = builder.mark_inputs(x)
|
93
|
-
y = torch.nn.functional.log_softmax(x, dim=dim)
|
94
|
-
y = builder.mark_outputs(y)
|
95
|
-
return y
|
96
|
-
|
97
|
-
def forward(self, x):
|
98
|
-
x = x + 1
|
99
|
-
x = self.log_softmax(x, 0)
|
100
|
-
x = self.log_softmax(x, 1)
|
101
|
-
return x
|
102
|
-
|
103
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
104
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2)
|
105
|
-
self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1)
|
106
|
-
self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1)
|
107
|
-
|
108
|
-
def test_build_composite_with_mix_type_attrs(self):
|
109
|
-
class SampleModel(torch.nn.Module):
|
110
|
-
|
111
|
-
def __init__(self):
|
112
|
-
super().__init__()
|
113
|
-
|
114
|
-
def log_softmax(self, x: torch.Tensor, dim: int):
|
115
|
-
builder = StableHLOCompositeBuilder(
|
116
|
-
name="test.log_softmax",
|
117
|
-
attr={
|
118
|
-
"dim": dim,
|
119
|
-
"source": "torch.nn",
|
120
|
-
"version": 1.0,
|
121
|
-
},
|
122
|
-
)
|
123
|
-
x = builder.mark_inputs(x)
|
124
|
-
y = torch.nn.functional.log_softmax(x, dim=dim)
|
125
|
-
y = builder.mark_outputs(y)
|
126
|
-
return y
|
127
|
-
|
128
|
-
def forward(self, x):
|
129
|
-
x = x + 1
|
130
|
-
x = self.log_softmax(x, 0)
|
131
|
-
return x
|
132
|
-
|
133
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
134
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
|
135
|
-
self.assertEqual(
|
136
|
-
mlir.count(
|
137
|
-
'composite_attributes = {dim = 0 : i64, source = "torch.nn",'
|
138
|
-
" version = 1.000000e+00 : f32}"
|
139
|
-
),
|
140
|
-
1,
|
141
|
-
)
|
142
|
-
|
143
|
-
def test_sdpa_composite(self):
|
144
|
-
class SDPAModel(torch.nn.Module):
|
145
|
-
|
146
|
-
def scaled_dot_product_attention(
|
147
|
-
self,
|
148
|
-
q: torch.Tensor,
|
149
|
-
k: torch.Tensor,
|
150
|
-
v: torch.Tensor,
|
151
|
-
head_size: int,
|
152
|
-
mask: torch.Tensor,
|
153
|
-
):
|
154
|
-
builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention")
|
155
|
-
q, k, v, mask = builder.mark_inputs(q, k, v, mask)
|
156
|
-
|
157
|
-
scale = 1.0 / math.sqrt(head_size)
|
158
|
-
|
159
|
-
q = q.transpose(1, 2)
|
160
|
-
k = k.transpose(1, 2)
|
161
|
-
v = v.transpose(1, 2)
|
162
|
-
y = F.scaled_dot_product_attention(
|
163
|
-
q,
|
164
|
-
k,
|
165
|
-
v,
|
166
|
-
attn_mask=mask,
|
167
|
-
dropout_p=0.0,
|
168
|
-
is_causal=mask is None,
|
169
|
-
scale=scale,
|
170
|
-
)
|
171
|
-
result = y.transpose(1, 2)
|
172
|
-
result = builder.mark_outputs(result)
|
173
|
-
return result
|
174
|
-
|
175
|
-
def forward(self, q, k, v, mask):
|
176
|
-
x = self.scaled_dot_product_attention(
|
177
|
-
q,
|
178
|
-
k,
|
179
|
-
v,
|
180
|
-
8,
|
181
|
-
mask,
|
182
|
-
)
|
183
|
-
return x
|
184
|
-
|
185
|
-
query = torch.rand(1, 1, 32, 4)
|
186
|
-
key = torch.rand(1, 500, 1, 4)
|
187
|
-
value = torch.rand(1, 500, 1, 4)
|
188
|
-
mask = torch.rand(1, 1, 1, 500)
|
189
|
-
|
190
|
-
mlir = _export_stablehlo_mlir(
|
191
|
-
SDPAModel().eval(),
|
192
|
-
(query, key, value, mask),
|
193
|
-
)
|
194
|
-
self.assertEqual(
|
195
|
-
mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1
|
196
|
-
)
|
197
|
-
|
198
|
-
def test_sdpa_composite_with_attr(self):
|
199
|
-
class SDPAModel(torch.nn.Module):
|
200
|
-
|
201
|
-
def scaled_dot_product_attention(
|
202
|
-
self,
|
203
|
-
q: torch.Tensor,
|
204
|
-
k: torch.Tensor,
|
205
|
-
v: torch.Tensor,
|
206
|
-
head_size: int,
|
207
|
-
include_captanh: bool,
|
208
|
-
):
|
209
|
-
builder = StableHLOCompositeBuilder(
|
210
|
-
name="test.scaled_dot_product_attention",
|
211
|
-
attr={"include_captanh": include_captanh},
|
212
|
-
)
|
213
|
-
q, k, v = builder.mark_inputs(q, k, v)
|
214
|
-
|
215
|
-
scale = 1.0 / math.sqrt(head_size)
|
216
|
-
|
217
|
-
q = q.transpose(1, 2)
|
218
|
-
k = k.transpose(1, 2)
|
219
|
-
v = v.transpose(1, 2)
|
220
|
-
y = F.scaled_dot_product_attention(
|
221
|
-
q,
|
222
|
-
k,
|
223
|
-
v,
|
224
|
-
attn_mask=None,
|
225
|
-
dropout_p=0.0,
|
226
|
-
is_causal=True,
|
227
|
-
scale=scale,
|
228
|
-
)
|
229
|
-
result = y.transpose(1, 2)
|
230
|
-
result = builder.mark_outputs(result)
|
231
|
-
return result
|
232
|
-
|
233
|
-
def forward(self, q, k, v):
|
234
|
-
x = self.scaled_dot_product_attention(q, k, v, 8, True)
|
235
|
-
y = self.scaled_dot_product_attention(q, k, v, 8, False)
|
236
|
-
return x + y
|
237
|
-
|
238
|
-
query = torch.rand(1, 1, 32, 4)
|
239
|
-
key = torch.rand(1, 500, 1, 4)
|
240
|
-
value = torch.rand(1, 500, 1, 4)
|
241
|
-
mlir = _export_stablehlo_mlir(
|
242
|
-
SDPAModel().eval(),
|
243
|
-
(query, key, value),
|
244
|
-
)
|
245
|
-
self.assertEqual(
|
246
|
-
mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
|
247
|
-
)
|
248
|
-
self.assertEqual(
|
249
|
-
mlir.count("composite_attributes = {include_captanh = true}"), 1
|
250
|
-
)
|
251
|
-
self.assertEqual(
|
252
|
-
mlir.count("composite_attributes = {include_captanh = false}"), 1
|
253
|
-
)
|
254
|
-
|
255
|
-
def test_build_composite_with_multiple_inputs_outputs(self):
|
256
|
-
class SampleModel(torch.nn.Module):
|
257
|
-
|
258
|
-
def mimo_sample(self, a, b, c):
|
259
|
-
builder = StableHLOCompositeBuilder(name="test.mimo_sample")
|
260
|
-
|
261
|
-
a, b, c = builder.mark_inputs(a, b, c)
|
262
|
-
x = a + b + c
|
263
|
-
y = (a - b) * x
|
264
|
-
z = (c + 1.0) * a
|
265
|
-
x, y, z = builder.mark_outputs(x, y, z)
|
266
|
-
|
267
|
-
result = x + y * z
|
268
|
-
return result
|
269
|
-
|
270
|
-
def forward(self, a, b, c):
|
271
|
-
x = self.mimo_sample(a, b, c)
|
272
|
-
x = self.mimo_sample(a, b, x)
|
273
|
-
x = self.mimo_sample(x, x, c)
|
274
|
-
return x
|
275
|
-
|
276
|
-
mlir = _export_stablehlo_mlir(
|
277
|
-
SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2))
|
278
|
-
)
|
279
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3)
|
280
|
-
|
281
|
-
|
282
|
-
if __name__ == "__main__":
|
283
|
-
googletest.main()
|
File without changes
|
File without changes
|