ai-edge-torch-nightly 0.3.0.dev20240815__py3-none-any.whl → 0.3.0.dev20240816__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/hlfb/test/test_mark_pattern.py +32 -8
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +5 -0
- ai_edge_torch/lowertools/__init__.py +1 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +3 -0
- ai_edge_torch/lowertools/test_utils.py +60 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240815.dist-info → ai_edge_torch_nightly-0.3.0.dev20240816.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240815.dist-info → ai_edge_torch_nightly-0.3.0.dev20240816.dist-info}/RECORD +11 -10
- {ai_edge_torch_nightly-0.3.0.dev20240815.dist-info → ai_edge_torch_nightly-0.3.0.dev20240816.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240815.dist-info → ai_edge_torch_nightly-0.3.0.dev20240816.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240815.dist-info → ai_edge_torch_nightly-0.3.0.dev20240816.dist-info}/top_level.txt +0 -0
|
@@ -51,7 +51,12 @@ class TestMarkPattern(googletest.TestCase):
|
|
|
51
51
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
|
52
52
|
mlir = _export_stablehlo_mlir(exported_program)
|
|
53
53
|
|
|
54
|
-
|
|
54
|
+
lowertools.assert_string_count(
|
|
55
|
+
self,
|
|
56
|
+
mlir,
|
|
57
|
+
{'stablehlo.composite "test.add"': 2},
|
|
58
|
+
{"stablehlo.custom_call @mark_tensor": 6},
|
|
59
|
+
)
|
|
55
60
|
|
|
56
61
|
def test_mark_pattern_with_attr_builder(self):
|
|
57
62
|
class TestModel(torch.nn.Module):
|
|
@@ -72,9 +77,15 @@ class TestMarkPattern(googletest.TestCase):
|
|
|
72
77
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
|
73
78
|
mlir = _export_stablehlo_mlir(exported_program)
|
|
74
79
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
mlir
|
|
80
|
+
lowertools.assert_string_count(
|
|
81
|
+
self,
|
|
82
|
+
mlir,
|
|
83
|
+
{
|
|
84
|
+
'stablehlo.composite "test.add"': 2,
|
|
85
|
+
'composite_attributes = {alias = "test.test_add"}': 2,
|
|
86
|
+
},
|
|
87
|
+
{"stablehlo.custom_call @mark_tensor": 6},
|
|
88
|
+
{'{"alias": "test.test_add"}': 2},
|
|
78
89
|
)
|
|
79
90
|
|
|
80
91
|
def test_mark_pattern_with_scalar_attr_tracker(self):
|
|
@@ -104,9 +115,17 @@ class TestMarkPattern(googletest.TestCase):
|
|
|
104
115
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
|
105
116
|
mlir = _export_stablehlo_mlir(exported_program)
|
|
106
117
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
118
|
+
lowertools.assert_string_count(
|
|
119
|
+
self,
|
|
120
|
+
mlir,
|
|
121
|
+
{
|
|
122
|
+
'stablehlo.composite "test.log_softmax"': 5,
|
|
123
|
+
"composite_attributes = {dim = 0 : i64}": 3,
|
|
124
|
+
"composite_attributes = {dim = 1 : i64}": 2,
|
|
125
|
+
},
|
|
126
|
+
{"stablehlo.custom_call @mark_tensor": 10},
|
|
127
|
+
{'{"dim": 0}': 3, '{"dim": 1}': 2},
|
|
128
|
+
)
|
|
110
129
|
|
|
111
130
|
def test_mark_tangent_model_and_pattern_input(self):
|
|
112
131
|
class TestModel(torch.nn.Module):
|
|
@@ -128,7 +147,12 @@ class TestMarkPattern(googletest.TestCase):
|
|
|
128
147
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
|
129
148
|
mlir = _export_stablehlo_mlir(exported_program)
|
|
130
149
|
|
|
131
|
-
|
|
150
|
+
lowertools.assert_string_count(
|
|
151
|
+
self,
|
|
152
|
+
mlir,
|
|
153
|
+
{'stablehlo.composite "test.relu"': 1},
|
|
154
|
+
{"stablehlo.custom_call @mark_tensor": 2},
|
|
155
|
+
)
|
|
132
156
|
|
|
133
157
|
|
|
134
158
|
if __name__ == "__main__":
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
import math
|
|
18
18
|
|
|
19
|
+
from ai_edge_torch import config
|
|
19
20
|
from ai_edge_torch import lowertools
|
|
20
21
|
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
21
22
|
import torch
|
|
@@ -29,6 +30,10 @@ def _export_stablehlo_mlir(model, args):
|
|
|
29
30
|
return lowertools.exported_program_to_mlir_text(ep)
|
|
30
31
|
|
|
31
32
|
|
|
33
|
+
@googletest.skipIf(
|
|
34
|
+
not config.Config.use_torch_xla,
|
|
35
|
+
reason="The odml_torch counter part is in odml_torch.",
|
|
36
|
+
)
|
|
32
37
|
class TestStableHLOCompositeBuilder(googletest.TestCase):
|
|
33
38
|
|
|
34
39
|
def test_build_composite(self):
|
|
@@ -28,6 +28,7 @@ import tensorflow as tf
|
|
|
28
28
|
import torch
|
|
29
29
|
|
|
30
30
|
from tensorflow.compiler.tf2xla.python import xla as tfxla
|
|
31
|
+
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
|
|
31
32
|
|
|
32
33
|
MlirBundle = odml_torch.export.MlirLowered
|
|
33
34
|
|
|
@@ -162,7 +163,9 @@ def merged_bundle_to_tfl_model(
|
|
|
162
163
|
)
|
|
163
164
|
|
|
164
165
|
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
|
|
166
|
+
converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
|
|
165
167
|
converter._experimental_enable_composite_direct_lowering = True
|
|
168
|
+
converter.model_origin_framework = "PYTORCH"
|
|
166
169
|
|
|
167
170
|
conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
|
|
168
171
|
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
from typing import Optional
|
|
18
|
+
from ai_edge_torch import config
|
|
19
|
+
from tensorflow.python.platform import googletest
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _extract_backend_configs(mlir):
|
|
23
|
+
mlir = mlir.replace("\\22", '"')
|
|
24
|
+
configs = []
|
|
25
|
+
for match in re.finditer(r"backend_config\s*=\s*\"(\{.*\})\"", mlir):
|
|
26
|
+
configs.append(match.group(1))
|
|
27
|
+
return "\n".join(configs)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def assert_string_count(
|
|
31
|
+
test_case: googletest.TestCase,
|
|
32
|
+
mlir: str,
|
|
33
|
+
torch_xla_pattern_counter: dict[str, int],
|
|
34
|
+
odml_torch_pattern_counter: dict[str, int],
|
|
35
|
+
odml_torch_attr_counter: Optional[dict[str, int]] = None,
|
|
36
|
+
):
|
|
37
|
+
|
|
38
|
+
if odml_torch_attr_counter is None:
|
|
39
|
+
odml_torch_attr_counter = {}
|
|
40
|
+
|
|
41
|
+
if config.Config.use_torch_xla:
|
|
42
|
+
for key in torch_xla_pattern_counter:
|
|
43
|
+
test_case.assertEqual(
|
|
44
|
+
mlir.count(key),
|
|
45
|
+
torch_xla_pattern_counter[key],
|
|
46
|
+
)
|
|
47
|
+
else:
|
|
48
|
+
for key in odml_torch_pattern_counter:
|
|
49
|
+
test_case.assertEqual(
|
|
50
|
+
mlir.count(key),
|
|
51
|
+
odml_torch_pattern_counter[key],
|
|
52
|
+
)
|
|
53
|
+
backend_configs = _extract_backend_configs(mlir)
|
|
54
|
+
print("backend_configs:")
|
|
55
|
+
print(backend_configs)
|
|
56
|
+
for key in odml_torch_attr_counter:
|
|
57
|
+
test_case.assertEqual(
|
|
58
|
+
backend_configs.count(key),
|
|
59
|
+
odml_torch_attr_counter[key],
|
|
60
|
+
)
|
ai_edge_torch/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.3.0.
|
|
3
|
+
Version: 0.3.0.dev20240816
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
|
4
4
|
ai_edge_torch/model.py,sha256=5DYNpFVwvI1w0JbAC1hn83NJVGS1WPX7n742419PMqs,4558
|
|
5
|
-
ai_edge_torch/version.py,sha256=
|
|
5
|
+
ai_edge_torch/version.py,sha256=UxuGFkLOzBGL1eu0wNKnHSaqw7KhjNX3AGtSb5cAvRo,706
|
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
|
@@ -122,12 +122,13 @@ ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k
|
|
|
122
122
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
|
123
123
|
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=uiYRfzD1T8deCEAGfdAFusRbI41m14zeTt0Lz5lNT3M,9808
|
|
124
124
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
125
|
-
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=
|
|
126
|
-
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=
|
|
127
|
-
ai_edge_torch/lowertools/__init__.py,sha256=
|
|
125
|
+
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=e53YNSO2w7Sd9Y717jAr6WKjnXq34Tx_52hXRGtGs3A,4833
|
|
126
|
+
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=7Qbba7GJCBc-J1TUwWIvrpBK0Hwza9nift7sKpW2YVE,8449
|
|
127
|
+
ai_edge_torch/lowertools/__init__.py,sha256=uKGibEN7n4Tqbe0HiXOEEXWmPL9AUmh34xaYA9yx2sg,719
|
|
128
128
|
ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
|
|
129
129
|
ai_edge_torch/lowertools/common_utils.py,sha256=emClsZ_MBlbLG_0BBtyLpkdz4dMWp6SyrNioygRBylk,2973
|
|
130
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
|
130
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=32cak8uiXFIVdkaYFhIW1fWG4NzLrYq-w8xK0pNkhYc,6547
|
|
131
|
+
ai_edge_torch/lowertools/test_utils.py,sha256=vsjaX3Ix2U1163jVUNSJgK9io2WNUtJjRvNFE9DrqF4,1932
|
|
131
132
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-g0NldtVOTCQtX3V2XEjuCQO_I52nSNQlu0r_rIS2IE,8635
|
|
132
133
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
|
133
134
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
|
@@ -136,8 +137,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
|
136
137
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
137
138
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
138
139
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
|
139
|
-
ai_edge_torch_nightly-0.3.0.
|
|
140
|
-
ai_edge_torch_nightly-0.3.0.
|
|
141
|
-
ai_edge_torch_nightly-0.3.0.
|
|
142
|
-
ai_edge_torch_nightly-0.3.0.
|
|
143
|
-
ai_edge_torch_nightly-0.3.0.
|
|
140
|
+
ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
141
|
+
ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/METADATA,sha256=c3BtvxGP-9OippYjgb8nO-_hfTW3hRVbPiFUH0ldb38,1885
|
|
142
|
+
ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
143
|
+
ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
144
|
+
ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|