ai-edge-torch-nightly 0.3.0.dev20240814__py3-none-any.whl → 0.3.0.dev20240816__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

@@ -51,7 +51,12 @@ class TestMarkPattern(googletest.TestCase):
51
51
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
52
52
  mlir = _export_stablehlo_mlir(exported_program)
53
53
 
54
- self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
54
+ lowertools.assert_string_count(
55
+ self,
56
+ mlir,
57
+ {'stablehlo.composite "test.add"': 2},
58
+ {"stablehlo.custom_call @mark_tensor": 6},
59
+ )
55
60
 
56
61
  def test_mark_pattern_with_attr_builder(self):
57
62
  class TestModel(torch.nn.Module):
@@ -72,9 +77,15 @@ class TestMarkPattern(googletest.TestCase):
72
77
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
73
78
  mlir = _export_stablehlo_mlir(exported_program)
74
79
 
75
- self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
76
- self.assertEqual(
77
- mlir.count('composite_attributes = {alias = "test.test_add"}'), 2
80
+ lowertools.assert_string_count(
81
+ self,
82
+ mlir,
83
+ {
84
+ 'stablehlo.composite "test.add"': 2,
85
+ 'composite_attributes = {alias = "test.test_add"}': 2,
86
+ },
87
+ {"stablehlo.custom_call @mark_tensor": 6},
88
+ {'{"alias": "test.test_add"}': 2},
78
89
  )
79
90
 
80
91
  def test_mark_pattern_with_scalar_attr_tracker(self):
@@ -104,9 +115,17 @@ class TestMarkPattern(googletest.TestCase):
104
115
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
105
116
  mlir = _export_stablehlo_mlir(exported_program)
106
117
 
107
- self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 5)
108
- self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 3)
109
- self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 2)
118
+ lowertools.assert_string_count(
119
+ self,
120
+ mlir,
121
+ {
122
+ 'stablehlo.composite "test.log_softmax"': 5,
123
+ "composite_attributes = {dim = 0 : i64}": 3,
124
+ "composite_attributes = {dim = 1 : i64}": 2,
125
+ },
126
+ {"stablehlo.custom_call @mark_tensor": 10},
127
+ {'{"dim": 0}': 3, '{"dim": 1}': 2},
128
+ )
110
129
 
111
130
  def test_mark_tangent_model_and_pattern_input(self):
112
131
  class TestModel(torch.nn.Module):
@@ -128,7 +147,12 @@ class TestMarkPattern(googletest.TestCase):
128
147
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
129
148
  mlir = _export_stablehlo_mlir(exported_program)
130
149
 
131
- self.assertEqual(mlir.count('stablehlo.composite "test.relu'), 1)
150
+ lowertools.assert_string_count(
151
+ self,
152
+ mlir,
153
+ {'stablehlo.composite "test.relu"': 1},
154
+ {"stablehlo.custom_call @mark_tensor": 2},
155
+ )
132
156
 
133
157
 
134
158
  if __name__ == "__main__":
@@ -16,6 +16,7 @@
16
16
 
17
17
  import math
18
18
 
19
+ from ai_edge_torch import config
19
20
  from ai_edge_torch import lowertools
20
21
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
21
22
  import torch
@@ -29,6 +30,10 @@ def _export_stablehlo_mlir(model, args):
29
30
  return lowertools.exported_program_to_mlir_text(ep)
30
31
 
31
32
 
33
+ @googletest.skipIf(
34
+ not config.Config.use_torch_xla,
35
+ reason="The odml_torch counter part is in odml_torch.",
36
+ )
32
37
  class TestStableHLOCompositeBuilder(googletest.TestCase):
33
38
 
34
39
  def test_build_composite(self):
@@ -14,3 +14,4 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from ._shim import *
17
+ from .test_utils import *
@@ -28,6 +28,7 @@ import tensorflow as tf
28
28
  import torch
29
29
 
30
30
  from tensorflow.compiler.tf2xla.python import xla as tfxla
31
+ from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
31
32
 
32
33
  MlirBundle = odml_torch.export.MlirLowered
33
34
 
@@ -162,7 +163,9 @@ def merged_bundle_to_tfl_model(
162
163
  )
163
164
 
164
165
  converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
166
+ converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
165
167
  converter._experimental_enable_composite_direct_lowering = True
168
+ converter.model_origin_framework = "PYTORCH"
166
169
 
167
170
  conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
168
171
 
@@ -0,0 +1,60 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import re
17
+ from typing import Optional
18
+ from ai_edge_torch import config
19
+ from tensorflow.python.platform import googletest
20
+
21
+
22
+ def _extract_backend_configs(mlir):
23
+ mlir = mlir.replace("\\22", '"')
24
+ configs = []
25
+ for match in re.finditer(r"backend_config\s*=\s*\"(\{.*\})\"", mlir):
26
+ configs.append(match.group(1))
27
+ return "\n".join(configs)
28
+
29
+
30
def assert_string_count(
    test_case: googletest.TestCase,
    mlir: str,
    torch_xla_pattern_counter: dict[str, int],
    odml_torch_pattern_counter: dict[str, int],
    odml_torch_attr_counter: Optional[dict[str, int]] = None,
):
  """Asserts backend-specific substring counts in exported MLIR text.

  Which expectations are checked depends on `config.Config.use_torch_xla`:

  * True (torch_xla): each key of `torch_xla_pattern_counter` must occur in
    `mlir` exactly the mapped number of times.
  * False (odml_torch): each key of `odml_torch_pattern_counter` must occur in
    `mlir` exactly the mapped number of times, and each key of
    `odml_torch_attr_counter` must occur that many times in the decoded
    `backend_config` payloads returned by `_extract_backend_configs(mlir)`.

  Occurrences are counted with `str.count` (non-overlapping substrings).

  Args:
    test_case: Test case whose `assertEqual` reports mismatches.
    mlir: MLIR text to inspect.
    torch_xla_pattern_counter: Substring -> expected count; checked only when
      `use_torch_xla` is set.
    odml_torch_pattern_counter: Substring -> expected count; checked only when
      `use_torch_xla` is not set.
    odml_torch_attr_counter: Optional substring -> expected count, checked
      against the decoded `backend_config` payloads (odml_torch only).
  """
  if config.Config.use_torch_xla:
    _assert_counts(test_case, mlir, torch_xla_pattern_counter, "MLIR")
    return

  _assert_counts(test_case, mlir, odml_torch_pattern_counter, "MLIR")
  # Only decode backend configs when there is something to check in them.
  if odml_torch_attr_counter:
    _assert_counts(
        test_case,
        _extract_backend_configs(mlir),
        odml_torch_attr_counter,
        "backend_config payloads",
    )


def _assert_counts(
    test_case: googletest.TestCase,
    text: str,
    expected_counts: dict[str, int],
    where: str,
):
  """Asserts `text.count(pattern) == expected` for every dict entry."""
  for pattern, expected in expected_counts.items():
    test_case.assertEqual(
        text.count(pattern),
        expected,
        msg=f"unexpected number of occurrences of {pattern!r} in {where}",
    )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240814"
16
+ __version__ = "0.3.0.dev20240816"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240814
3
+ Version: 0.3.0.dev20240816
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=5DYNpFVwvI1w0JbAC1hn83NJVGS1WPX7n742419PMqs,4558
5
- ai_edge_torch/version.py,sha256=BlH3JqkXwVHXFYAd5rF04dUvLCthvKVqnfgO3abgh14,706
5
+ ai_edge_torch/version.py,sha256=UxuGFkLOzBGL1eu0wNKnHSaqw7KhjNX3AGtSb5cAvRo,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -122,12 +122,13 @@ ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k
122
122
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
123
123
  ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=uiYRfzD1T8deCEAGfdAFusRbI41m14zeTt0Lz5lNT3M,9808
124
124
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
125
- ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=oYB0RPW-tHOwW9gQFH9GtHKO_Mmh1lkoiemXmTfySqc,4383
126
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=3vSX5E9ZFFhTPZZX6TMiAsGa_kaXABbN851bRbTFsC0,8297
127
- ai_edge_torch/lowertools/__init__.py,sha256=0M9TOR80sS5y6dikOsIFYx0P9IomqAdNIuYpgkP4PcI,693
125
+ ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=e53YNSO2w7Sd9Y717jAr6WKjnXq34Tx_52hXRGtGs3A,4833
126
+ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=7Qbba7GJCBc-J1TUwWIvrpBK0Hwza9nift7sKpW2YVE,8449
127
+ ai_edge_torch/lowertools/__init__.py,sha256=uKGibEN7n4Tqbe0HiXOEEXWmPL9AUmh34xaYA9yx2sg,719
128
128
  ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
129
129
  ai_edge_torch/lowertools/common_utils.py,sha256=emClsZ_MBlbLG_0BBtyLpkdz4dMWp6SyrNioygRBylk,2973
130
- ai_edge_torch/lowertools/odml_torch_utils.py,sha256=5XVI2ovptp1wsDZcyyaZDgT4oUa1McOiE-PrKhXNhFo,6316
130
+ ai_edge_torch/lowertools/odml_torch_utils.py,sha256=32cak8uiXFIVdkaYFhIW1fWG4NzLrYq-w8xK0pNkhYc,6547
131
+ ai_edge_torch/lowertools/test_utils.py,sha256=vsjaX3Ix2U1163jVUNSJgK9io2WNUtJjRvNFE9DrqF4,1932
131
132
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-g0NldtVOTCQtX3V2XEjuCQO_I52nSNQlu0r_rIS2IE,8635
132
133
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
133
134
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -136,8 +137,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
136
137
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
137
138
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
138
139
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
139
- ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
140
- ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/METADATA,sha256=eYXq0PpFouGnXKu9vXIzyaXj8XsLDxlDn903GJFR3ak,1885
141
- ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
142
- ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
143
- ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/RECORD,,
140
+ ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
141
+ ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/METADATA,sha256=c3BtvxGP-9OippYjgb8nO-_hfTW3hRVbPiFUH0ldb38,1885
142
+ ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
143
+ ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
144
+ ai_edge_torch_nightly-0.3.0.dev20240816.dist-info/RECORD,,