ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.3.0.dev20240809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.3.0.dev20240809.dist-info/RECORD +141 -0
- ai_edge_torch/convert/conversion_utils.py +0 -439
- ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/top_level.txt +0 -0
|
@@ -12,24 +12,25 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Tests for ai_edge_torch.convert."""
|
|
15
16
|
|
|
16
|
-
|
|
17
|
-
from dataclasses import dataclass
|
|
17
|
+
import dataclasses
|
|
18
18
|
import os
|
|
19
|
-
import tempfile
|
|
20
19
|
from typing import Tuple
|
|
21
|
-
import unittest
|
|
22
20
|
|
|
23
21
|
import ai_edge_torch
|
|
24
|
-
from ai_edge_torch
|
|
22
|
+
from ai_edge_torch import config
|
|
23
|
+
from ai_edge_torch._convert import conversion_utils
|
|
25
24
|
from ai_edge_torch.testing import model_coverage
|
|
26
25
|
import numpy as np
|
|
27
26
|
import tensorflow as tf
|
|
28
27
|
import torch
|
|
29
28
|
import torchvision
|
|
30
29
|
|
|
30
|
+
from tensorflow.python.platform import googletest
|
|
31
|
+
|
|
31
32
|
|
|
32
|
-
@dataclass
|
|
33
|
+
@dataclasses.dataclass
|
|
33
34
|
class TestContainer1:
|
|
34
35
|
data_1: torch.Tensor
|
|
35
36
|
data_2: Tuple[torch.Tensor, torch.Tensor]
|
|
@@ -40,10 +41,11 @@ torch.export.register_dataclass(
|
|
|
40
41
|
)
|
|
41
42
|
|
|
42
43
|
|
|
43
|
-
class TestConvert(
|
|
44
|
+
class TestConvert(googletest.TestCase):
|
|
44
45
|
"""Tests conversion of various modules."""
|
|
45
46
|
|
|
46
47
|
def setUp(self):
|
|
48
|
+
super().setUp()
|
|
47
49
|
torch.manual_seed(0)
|
|
48
50
|
|
|
49
51
|
def test_convert_add(self):
|
|
@@ -66,8 +68,9 @@ class TestConvert(unittest.TestCase):
|
|
|
66
68
|
)
|
|
67
69
|
|
|
68
70
|
def test_convert_dot_add(self):
|
|
71
|
+
"""Tests conversion of a matrix multiplication followed by an add."""
|
|
72
|
+
|
|
69
73
|
class DotAdd(torch.nn.Module):
|
|
70
|
-
"""Tests conversion of a matrix multiplication followed by an add."""
|
|
71
74
|
|
|
72
75
|
def forward(self, a, b, c):
|
|
73
76
|
return a @ b + c
|
|
@@ -97,20 +100,21 @@ class TestConvert(unittest.TestCase):
|
|
|
97
100
|
"""Tests conversion of a model with more than 10 arguments."""
|
|
98
101
|
|
|
99
102
|
class AddChainWith11Args(torch.nn.Module):
|
|
103
|
+
"""A model with 11 arguments."""
|
|
100
104
|
|
|
101
105
|
def forward(
|
|
102
106
|
self,
|
|
103
|
-
arg0:
|
|
104
|
-
arg1:
|
|
105
|
-
arg2:
|
|
106
|
-
arg3:
|
|
107
|
-
arg4:
|
|
108
|
-
arg5:
|
|
109
|
-
arg6:
|
|
110
|
-
arg7:
|
|
111
|
-
arg8:
|
|
112
|
-
arg9:
|
|
113
|
-
arg10:
|
|
107
|
+
arg0: torch.Tensor,
|
|
108
|
+
arg1: torch.Tensor,
|
|
109
|
+
arg2: torch.Tensor,
|
|
110
|
+
arg3: torch.Tensor,
|
|
111
|
+
arg4: torch.Tensor,
|
|
112
|
+
arg5: torch.Tensor,
|
|
113
|
+
arg6: torch.Tensor,
|
|
114
|
+
arg7: torch.Tensor,
|
|
115
|
+
arg8: torch.Tensor,
|
|
116
|
+
arg9: torch.Tensor,
|
|
117
|
+
arg10: torch.Tensor,
|
|
114
118
|
):
|
|
115
119
|
add0 = torch.add(arg0, arg1)
|
|
116
120
|
add1 = torch.add(add0, arg2)
|
|
@@ -149,6 +153,7 @@ class TestConvert(unittest.TestCase):
|
|
|
149
153
|
"""Tests conversion of a model that returns multiple outputs."""
|
|
150
154
|
|
|
151
155
|
class BasicAddModelWithMultipleOutputs(torch.nn.Module):
|
|
156
|
+
"""A model that returns multiple outputs."""
|
|
152
157
|
|
|
153
158
|
def forward(self, arg0, arg1):
|
|
154
159
|
add0 = arg0 + arg1
|
|
@@ -172,6 +177,7 @@ class TestConvert(unittest.TestCase):
|
|
|
172
177
|
"""Tests conversion of a model that returns multiple outputs."""
|
|
173
178
|
|
|
174
179
|
class BasicAddModelWithMultipleOutputs(torch.nn.Module):
|
|
180
|
+
"""A model that returns multiple outputs."""
|
|
175
181
|
|
|
176
182
|
def forward(self, arg0, arg1):
|
|
177
183
|
add0 = arg0 + arg1
|
|
@@ -215,8 +221,8 @@ class TestConvert(unittest.TestCase):
|
|
|
215
221
|
)
|
|
216
222
|
self.assertTrue(result)
|
|
217
223
|
|
|
218
|
-
def
|
|
219
|
-
"""Tests if
|
|
224
|
+
def test_apply_tfl_converter_flags(self):
|
|
225
|
+
"""Tests if _apply_tfl_converter_flags correctly sets the values in a Converter object."""
|
|
220
226
|
|
|
221
227
|
class MockConverterInternalObject:
|
|
222
228
|
|
|
@@ -231,12 +237,12 @@ class TestConvert(unittest.TestCase):
|
|
|
231
237
|
|
|
232
238
|
mock_converter = MockConverter()
|
|
233
239
|
flags = {"key1": "new_value1", "key2": {"subkey2": "new_subvalue2"}}
|
|
234
|
-
|
|
240
|
+
conversion_utils.apply_tfl_converter_flags(mock_converter, flags)
|
|
235
241
|
|
|
236
242
|
self.assertTrue(flags["key1"], "new_value1")
|
|
237
243
|
self.assertTrue(flags["key2"]["subkey2"], "new_subvalue2")
|
|
238
244
|
|
|
239
|
-
def
|
|
245
|
+
def test_convert_add_converter_flags(self):
|
|
240
246
|
"""Tests conversion of an add module setting a tflite converter flag."""
|
|
241
247
|
|
|
242
248
|
class Add(torch.nn.Module):
|
|
@@ -250,21 +256,23 @@ class TestConvert(unittest.TestCase):
|
|
|
250
256
|
)
|
|
251
257
|
torch_module = Add().eval()
|
|
252
258
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
259
|
+
tmp_dir_path = self.create_tempdir()
|
|
260
|
+
ir_dump_path = os.path.join(
|
|
261
|
+
tmp_dir_path, "test_convert_add_converter_flags_mlir_dump"
|
|
262
|
+
)
|
|
263
|
+
ai_edge_torch.convert(
|
|
264
|
+
torch_module,
|
|
265
|
+
args,
|
|
266
|
+
_ai_edge_converter_flags={"ir_dump_dir": ir_dump_path},
|
|
267
|
+
)
|
|
268
|
+
self.assertTrue(os.path.isdir(ir_dump_path))
|
|
263
269
|
|
|
270
|
+
@googletest.skipIf(
|
|
271
|
+
not config.Config.use_torch_xla,
|
|
272
|
+
reason="Shape polymorphism is not yet support with odml_torch.",
|
|
273
|
+
)
|
|
264
274
|
def test_convert_model_with_dynamic_batch(self):
|
|
265
|
-
"""
|
|
266
|
-
Test converting a simple model with dynamic batch size.
|
|
267
|
-
"""
|
|
275
|
+
"""Test converting a simple model with dynamic batch size."""
|
|
268
276
|
|
|
269
277
|
class SampleModel(torch.nn.Module):
|
|
270
278
|
|
|
@@ -294,9 +302,7 @@ class TestConvert(unittest.TestCase):
|
|
|
294
302
|
)
|
|
295
303
|
|
|
296
304
|
def test_convert_model_with_kwargs(self):
|
|
297
|
-
"""
|
|
298
|
-
Test converting a simple model with sample_kwargs.
|
|
299
|
-
"""
|
|
305
|
+
"""Test converting a simple model with sample_kwargs."""
|
|
300
306
|
|
|
301
307
|
class SampleModel(torch.nn.Module):
|
|
302
308
|
|
|
@@ -315,9 +321,7 @@ class TestConvert(unittest.TestCase):
|
|
|
315
321
|
)
|
|
316
322
|
|
|
317
323
|
def test_convert_model_with_args_kwargs(self):
|
|
318
|
-
"""
|
|
319
|
-
Test converting a simple model with both sample_args and sample_kwargs.
|
|
320
|
-
"""
|
|
324
|
+
"""Test converting a simple model with both sample_args and sample_kwargs."""
|
|
321
325
|
|
|
322
326
|
class SampleModel(torch.nn.Module):
|
|
323
327
|
|
|
@@ -337,9 +341,7 @@ class TestConvert(unittest.TestCase):
|
|
|
337
341
|
)
|
|
338
342
|
|
|
339
343
|
def test_convert_model_with_args_nested_kwargs_1(self):
|
|
340
|
-
"""
|
|
341
|
-
Test converting a simple model with both sample_args and nested sample_kwargs.
|
|
342
|
-
"""
|
|
344
|
+
"""Test converting a simple model with both sample_args and nested sample_kwargs."""
|
|
343
345
|
|
|
344
346
|
class SampleModel(torch.nn.Module):
|
|
345
347
|
|
|
@@ -366,9 +368,7 @@ class TestConvert(unittest.TestCase):
|
|
|
366
368
|
)
|
|
367
369
|
|
|
368
370
|
def test_convert_model_with_args_nested_kwargs_2(self):
|
|
369
|
-
"""
|
|
370
|
-
Test converting a simple model with both sample_args and nested sample_kwargs.
|
|
371
|
-
"""
|
|
371
|
+
"""Test converting a simple model with both sample_args and nested sample_kwargs."""
|
|
372
372
|
|
|
373
373
|
class SampleModel(torch.nn.Module):
|
|
374
374
|
|
|
@@ -395,9 +395,7 @@ class TestConvert(unittest.TestCase):
|
|
|
395
395
|
)
|
|
396
396
|
|
|
397
397
|
def test_convert_model_with_args_nested_kwargs_3(self):
|
|
398
|
-
"""
|
|
399
|
-
Test converting a simple model with both sample_args and nested sample_kwargs.
|
|
400
|
-
"""
|
|
398
|
+
"""Test converting a simple model with both sample_args and nested sample_kwargs."""
|
|
401
399
|
|
|
402
400
|
class SampleModel(torch.nn.Module):
|
|
403
401
|
|
|
@@ -437,4 +435,4 @@ class TestConvert(unittest.TestCase):
|
|
|
437
435
|
|
|
438
436
|
|
|
439
437
|
if __name__ == "__main__":
|
|
440
|
-
|
|
438
|
+
googletest.main()
|
|
@@ -12,18 +12,21 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Tests conversion modules that are meant to be wrapped as composites."""
|
|
15
16
|
|
|
16
|
-
|
|
17
|
-
from typing import Callable
|
|
18
|
-
import unittest
|
|
17
|
+
from collections.abc import Callable
|
|
19
18
|
|
|
20
19
|
import ai_edge_torch
|
|
21
20
|
from ai_edge_torch.testing import model_coverage
|
|
22
21
|
import parameterized
|
|
23
22
|
import torch
|
|
24
23
|
|
|
24
|
+
from tensorflow.python.platform import googletest
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _func_to_torch_module(func: Callable[..., torch.Tensor]):
|
|
28
|
+
"""Wraps a function into a torch module."""
|
|
25
29
|
|
|
26
|
-
def _func_to_torch_module(func: Callable):
|
|
27
30
|
class TestModule(torch.nn.Module):
|
|
28
31
|
|
|
29
32
|
def __init__(self, func):
|
|
@@ -36,7 +39,7 @@ def _func_to_torch_module(func: Callable):
|
|
|
36
39
|
return TestModule(func).eval()
|
|
37
40
|
|
|
38
41
|
|
|
39
|
-
class TestConvertComposites(
|
|
42
|
+
class TestConvertComposites(googletest.TestCase):
|
|
40
43
|
"""Tests conversion modules that are meant to be wrapped as composites."""
|
|
41
44
|
|
|
42
45
|
def test_convert_hardswish(self):
|
|
@@ -51,7 +54,8 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
51
54
|
)
|
|
52
55
|
|
|
53
56
|
@parameterized.parameterized.expand([
|
|
54
|
-
# input_size, kernel_size, stride, padding, ceil_mode,
|
|
57
|
+
# (input_size, kernel_size, stride, padding, ceil_mode,
|
|
58
|
+
# count_include_pad, divisor_override)
|
|
55
59
|
# no padding, stride = 1
|
|
56
60
|
([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
|
|
57
61
|
# add stride
|
|
@@ -64,6 +68,8 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
64
68
|
([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
|
|
65
69
|
# add both stride and padding
|
|
66
70
|
([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
|
|
71
|
+
# padding set to one number
|
|
72
|
+
([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
|
|
67
73
|
# count_include_pad = False
|
|
68
74
|
([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
|
|
69
75
|
# ceil_mode = True
|
|
@@ -72,8 +78,6 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
72
78
|
([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
|
|
73
79
|
# set divisor_override
|
|
74
80
|
([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
|
|
75
|
-
# padding set to one number
|
|
76
|
-
([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
|
|
77
81
|
])
|
|
78
82
|
def test_convert_avg_pool2d(self, input_size, *args):
|
|
79
83
|
"""Tests conversion of a module containing an avg_pool2d aten."""
|
|
@@ -111,7 +115,7 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
111
115
|
def test_convert_upsample_bilinear_functional(self, input_size, kwargs):
|
|
112
116
|
"""Tests conversion of a torch.nn.functional.upsample module."""
|
|
113
117
|
torch_module = _func_to_torch_module(
|
|
114
|
-
lambda input_tensor: torch.nn.functional.upsample(
|
|
118
|
+
lambda input_tensor: torch.nn.functional.upsample( # pylint: disable=unnecessary-lambda
|
|
115
119
|
input_tensor, **kwargs
|
|
116
120
|
)
|
|
117
121
|
)
|
|
@@ -146,7 +150,7 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
146
150
|
def test_convert_upsample_bilinear(self, input_size, kwargs):
|
|
147
151
|
"""Tests conversion of a torch.nn.Upsample module."""
|
|
148
152
|
torch_module = _func_to_torch_module(
|
|
149
|
-
lambda input_tensor: torch.nn.Upsample(**kwargs)(input_tensor)
|
|
153
|
+
lambda input_tensor: torch.nn.Upsample(**kwargs)(input_tensor) # pylint: disable=unnecessary-lambda
|
|
150
154
|
)
|
|
151
155
|
tracing_args = (torch.randn(*input_size),)
|
|
152
156
|
edge_model = ai_edge_torch.convert(torch_module, tracing_args)
|
|
@@ -179,7 +183,7 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
179
183
|
def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
|
|
180
184
|
"""Tests conversion of a torch.nn.functional.interpolate module."""
|
|
181
185
|
torch_module = _func_to_torch_module(
|
|
182
|
-
lambda input_tensor: torch.nn.functional.interpolate(
|
|
186
|
+
lambda input_tensor: torch.nn.functional.interpolate( # pylint: disable=unnecessary-lambda
|
|
183
187
|
input_tensor, **kwargs
|
|
184
188
|
)
|
|
185
189
|
)
|
|
@@ -227,4 +231,4 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
227
231
|
|
|
228
232
|
|
|
229
233
|
if __name__ == '__main__':
|
|
230
|
-
|
|
234
|
+
googletest.main()
|
|
@@ -12,19 +12,21 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
import unittest
|
|
15
|
+
"""Tests for multi-signature conversion."""
|
|
17
16
|
|
|
18
17
|
import ai_edge_torch
|
|
19
18
|
from ai_edge_torch.testing import model_coverage
|
|
20
19
|
import torch
|
|
21
20
|
import torchvision
|
|
22
21
|
|
|
22
|
+
from tensorflow.python.platform import googletest
|
|
23
|
+
|
|
23
24
|
|
|
24
|
-
class TestConvertMultiSignature(
|
|
25
|
+
class TestConvertMultiSignature(googletest.TestCase):
|
|
25
26
|
"""Tests conversion of various modules through multi-signature conversion."""
|
|
26
27
|
|
|
27
28
|
def setUp(self):
|
|
29
|
+
super().setUp()
|
|
28
30
|
torch.manual_seed(0)
|
|
29
31
|
|
|
30
32
|
def test_convert_mobilenet_v2_with_default(self):
|
|
@@ -144,4 +146,4 @@ class TestConvertMultiSignature(unittest.TestCase):
|
|
|
144
146
|
|
|
145
147
|
|
|
146
148
|
if __name__ == "__main__":
|
|
147
|
-
|
|
149
|
+
googletest.main()
|
|
@@ -12,12 +12,13 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
import unittest
|
|
15
|
+
"""Tests for to_channel_last_io API and module wrapper."""
|
|
17
16
|
|
|
18
17
|
import ai_edge_torch
|
|
19
18
|
import torch
|
|
20
19
|
|
|
20
|
+
from tensorflow.python.platform import googletest
|
|
21
|
+
|
|
21
22
|
|
|
22
23
|
class Identity(torch.nn.Module):
|
|
23
24
|
|
|
@@ -25,7 +26,7 @@ class Identity(torch.nn.Module):
|
|
|
25
26
|
return x
|
|
26
27
|
|
|
27
28
|
|
|
28
|
-
class TestToChannelLastIO(
|
|
29
|
+
class TestToChannelLastIO(googletest.TestCase):
|
|
29
30
|
"""Tests to_channel_last_io API and module wrapper."""
|
|
30
31
|
|
|
31
32
|
def test_no_transformations(self):
|
|
@@ -92,4 +93,4 @@ class TestToChannelLastIO(unittest.TestCase):
|
|
|
92
93
|
|
|
93
94
|
|
|
94
95
|
if __name__ == "__main__":
|
|
95
|
-
|
|
96
|
+
googletest.main()
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Transforms the input and output of a module to channel last layout."""
|
|
15
16
|
|
|
16
17
|
from typing import Optional
|
|
17
18
|
|
|
@@ -82,8 +83,10 @@ def to_channel_last_io(
|
|
|
82
83
|
(N, C, ...) to channel last (N, ..., C).
|
|
83
84
|
outputs (list[int]): Transform outputs with indices in the list from channel
|
|
84
85
|
first (N, C, ...) to channel last (N, ..., C).
|
|
86
|
+
|
|
85
87
|
Returns:
|
|
86
|
-
The wrapped nn.Module with additional layout transposes after inputs and/or
|
|
88
|
+
The wrapped nn.Module with additional layout transposes after inputs and/or
|
|
89
|
+
before
|
|
87
90
|
outputs.
|
|
88
91
|
"""
|
|
89
92
|
return ChannelLastIOWrapper(module, args=args, outputs=outputs)
|
ai_edge_torch/config.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Provides a configuration for the AI Edge Torch library."""
|
|
17
|
+
|
|
18
|
+
import dataclasses
|
|
19
|
+
import os
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclasses.dataclass
|
|
23
|
+
class Config:
|
|
24
|
+
use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "True") == "True"
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from absl import flags
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def pytest_configure(config):
|
|
20
|
+
flags.FLAGS.mark_as_parsed()
|
ai_edge_torch/debug/culprit.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Culprit finder for AI Edge Torch conversion."""
|
|
15
16
|
|
|
16
17
|
import contextlib
|
|
17
18
|
import copy
|
|
@@ -20,14 +21,13 @@ import functools
|
|
|
20
21
|
import io
|
|
21
22
|
import operator
|
|
22
23
|
import os
|
|
23
|
-
import sys
|
|
24
24
|
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
|
25
25
|
|
|
26
26
|
import ai_edge_torch
|
|
27
27
|
from ai_edge_torch.debug import utils
|
|
28
|
-
from functorch.compile import minifier as fx_minifier
|
|
29
28
|
import torch
|
|
30
29
|
from torch._functorch import aot_autograd
|
|
30
|
+
from torch._functorch.fx_minifier import minifier as fx_minifier
|
|
31
31
|
import torch.utils._pytree as pytree
|
|
32
32
|
|
|
33
33
|
_torch_float_dtypes = {
|
|
@@ -116,7 +116,7 @@ class Culprit(SearchResult):
|
|
|
116
116
|
print_output: bool - If true, prints the code to stdout. Otherwise returns
|
|
117
117
|
the code in a str.
|
|
118
118
|
"""
|
|
119
|
-
# TODO
|
|
119
|
+
# TODO: b/321263453 - Support Python code gen with sample arg tensor values.
|
|
120
120
|
random_inputs = True
|
|
121
121
|
|
|
122
122
|
graph_module_code = self.graph_module.print_readable(
|
|
@@ -152,6 +152,7 @@ class Culprit(SearchResult):
|
|
|
152
152
|
|
|
153
153
|
def print_code(self, print_output=True):
|
|
154
154
|
"""Print the Python code for culprit graph module, sample args, and AI
|
|
155
|
+
|
|
155
156
|
Edge Torch conversion that will fail with the error.
|
|
156
157
|
|
|
157
158
|
Args:
|
|
@@ -188,8 +189,8 @@ class Culprit(SearchResult):
|
|
|
188
189
|
|
|
189
190
|
|
|
190
191
|
def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
|
|
191
|
-
"""
|
|
192
|
-
|
|
192
|
+
"""This function turns all operator getitem nodes in ExportedProgram FX graph to
|
|
193
|
+
|
|
193
194
|
new nodes composed of "computation + getitem". The normalization duplicates
|
|
194
195
|
some computations in the graph but would make the graph more friendly for
|
|
195
196
|
partitioning in FX minifier.
|
|
@@ -367,19 +368,18 @@ def _search_model(
|
|
|
367
368
|
max_granularity: Optional[int] = None,
|
|
368
369
|
enable_fx_minifier_logging: bool = False,
|
|
369
370
|
) -> Generator[SearchResult, None, None]:
|
|
370
|
-
"""Finds subgraphs in the torch model that
|
|
371
|
+
"""Finds subgraphs in the torch model that satify a certain predicate function provided by the users.
|
|
371
372
|
|
|
372
373
|
Args:
|
|
373
|
-
predicate_f: a predicate function the users specify.
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
return False otherwise.
|
|
374
|
+
predicate_f: a predicate function the users specify. It takes a FX
|
|
375
|
+
(sub)graph and the inputs to this graph, return True if the graph
|
|
376
|
+
satisfies the predicate, return False otherwise.
|
|
377
377
|
model: model in which to search subgraph.
|
|
378
|
-
export_args: A set of args to trace the model with,
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
378
|
+
export_args: A set of args to trace the model with, i.e. model(*args) must
|
|
379
|
+
run. max_granularity - FX minifier arg. The maximum granularity (number of
|
|
380
|
+
nodes) in the returned ATen FX subgraph of the culprit.
|
|
381
|
+
enable_fx_minifier_logging: If true, allows the underlying FX minifier to
|
|
382
|
+
log the progress.
|
|
383
383
|
"""
|
|
384
384
|
|
|
385
385
|
if isinstance(model, torch.nn.Module):
|
|
@@ -469,13 +469,13 @@ def find_culprits(
|
|
|
469
469
|
|
|
470
470
|
Args:
|
|
471
471
|
torch_model: model to export and save
|
|
472
|
-
args: A set of args to trace the model with, i.e.
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
472
|
+
args: A set of args to trace the model with, i.e. torch_model(*args) must
|
|
473
|
+
run max_granularity - FX minifier arg. The maximum granularity (number of
|
|
474
|
+
nodes) in the returned ATen FX subgraph of the culprit.
|
|
475
|
+
runtime_errors: If true, find culprits for Python runtime errors with
|
|
476
|
+
converted model.
|
|
477
|
+
enable_fx_minifier_logging: If true, allows the underlying FX minifier to
|
|
478
|
+
log the progress.
|
|
479
479
|
"""
|
|
480
480
|
|
|
481
481
|
fx_minifier_checker = functools.partial(
|
|
@@ -17,11 +17,12 @@
|
|
|
17
17
|
import ast
|
|
18
18
|
import io
|
|
19
19
|
import sys
|
|
20
|
-
import unittest
|
|
21
20
|
|
|
22
21
|
from ai_edge_torch.debug import find_culprits
|
|
23
22
|
import torch
|
|
24
23
|
|
|
24
|
+
from tensorflow.python.platform import googletest
|
|
25
|
+
|
|
25
26
|
_test_culprit_lib = torch.library.Library("test_culprit", "DEF")
|
|
26
27
|
|
|
27
28
|
_test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
|
|
@@ -49,7 +50,7 @@ class BadModel(torch.nn.Module):
|
|
|
49
50
|
return x
|
|
50
51
|
|
|
51
52
|
|
|
52
|
-
class TestCulprit(
|
|
53
|
+
class TestCulprit(googletest.TestCase):
|
|
53
54
|
|
|
54
55
|
def test_find_culprits(self):
|
|
55
56
|
model = BadModel().eval()
|
|
@@ -131,4 +132,4 @@ class TestCulprit(unittest.TestCase):
|
|
|
131
132
|
|
|
132
133
|
|
|
133
134
|
if __name__ == "__main__":
|
|
134
|
-
|
|
135
|
+
googletest.main()
|
|
@@ -12,15 +12,15 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
import unittest
|
|
15
|
+
"""Tests for search_model."""
|
|
18
16
|
|
|
19
17
|
from ai_edge_torch.debug import _search_model
|
|
20
18
|
import torch
|
|
21
19
|
|
|
20
|
+
from tensorflow.python.platform import googletest
|
|
21
|
+
|
|
22
22
|
|
|
23
|
-
class TestSearchModel(
|
|
23
|
+
class TestSearchModel(googletest.TestCase):
|
|
24
24
|
|
|
25
25
|
def test_search_model_with_ops(self):
|
|
26
26
|
class MultipleOpsModel(torch.nn.Module):
|
|
@@ -48,4 +48,4 @@ class TestSearchModel(unittest.TestCase):
|
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
if __name__ == "__main__":
|
|
51
|
-
|
|
51
|
+
googletest.main()
|
ai_edge_torch/debug/utils.py
CHANGED
|
@@ -12,12 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Utils for debugging."""
|
|
16
|
+
|
|
15
17
|
import contextlib
|
|
16
18
|
import sys
|
|
17
19
|
|
|
18
20
|
import torch
|
|
19
|
-
from torch.export.graph_signature import InputKind
|
|
20
|
-
import torch.fx._pytree as fx_pytree
|
|
21
21
|
from torch.utils import _pytree as pytree
|
|
22
22
|
|
|
23
23
|
|
|
@@ -33,6 +33,15 @@ def exported_program_to_fx_graph_module_and_inputs(
|
|
|
33
33
|
|
|
34
34
|
@contextlib.contextmanager
|
|
35
35
|
def redirect_stdio(stdout, stderr):
|
|
36
|
+
"""Redirects stdout and stderr to the given file objects.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
stdout: A file object to redirect stdout to.
|
|
40
|
+
stderr: A file object to redirect stderr to.
|
|
41
|
+
|
|
42
|
+
Yields:
|
|
43
|
+
The file objects that stdout and stderr were redirected to.
|
|
44
|
+
"""
|
|
36
45
|
old_stdout = sys.stdout
|
|
37
46
|
old_stderr = sys.stderr
|
|
38
47
|
|
|
@@ -34,8 +34,8 @@ def convert_gemma_to_tflite(
|
|
|
34
34
|
quantize: bool = True,
|
|
35
35
|
):
|
|
36
36
|
"""An example method for converting a Gemma 2B model to multi-signature
|
|
37
|
-
tflite model.
|
|
38
37
|
|
|
38
|
+
tflite model.
|
|
39
39
|
Args:
|
|
40
40
|
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
|
41
41
|
holding the checkpoint.
|
|
@@ -43,8 +43,8 @@ def convert_gemma_to_tflite(
|
|
|
43
43
|
Defaults to 512.
|
|
44
44
|
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
|
45
45
|
including both prefill and decode. Defaults to 1024.
|
|
46
|
-
quantize (bool, optional): Whether the model should be quanized.
|
|
47
|
-
|
|
46
|
+
quantize (bool, optional): Whether the model should be quanized. Defaults
|
|
47
|
+
to True.
|
|
48
48
|
"""
|
|
49
49
|
pytorch_model = gemma.build_2b_model(
|
|
50
50
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
|
@@ -73,7 +73,9 @@ class Gemma(nn.Module):
|
|
|
73
73
|
)
|
|
74
74
|
self.rope_cache = attn_utils.build_rope_cache(
|
|
75
75
|
size=config.kv_cache_max,
|
|
76
|
-
dim=int(
|
|
76
|
+
dim=int(
|
|
77
|
+
config.attn_config.rotary_percentage * config.attn_config.head_dim
|
|
78
|
+
),
|
|
77
79
|
base=10_000,
|
|
78
80
|
condense_ratio=1,
|
|
79
81
|
dtype=torch.float32,
|
|
@@ -125,6 +127,7 @@ class Gemma(nn.Module):
|
|
|
125
127
|
def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
126
128
|
attn_config = cfg.AttentionConfig(
|
|
127
129
|
num_heads=8,
|
|
130
|
+
head_dim=256,
|
|
128
131
|
num_query_groups=1,
|
|
129
132
|
rotary_percentage=1.0,
|
|
130
133
|
)
|