ai-edge-torch-nightly 0.3.0.dev20240905__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -129,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
129
129
  )
130
130
 
131
131
  copied_model = copy.deepcopy(pytorch_model)
132
+ copied_edge = copy.deepcopy(edge_model)
132
133
 
133
134
  self.assertTrue(
134
135
  model_coverage.compare_tflite_torch(
@@ -140,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
140
141
  )
141
142
  )
142
143
 
143
- # TODO(b/362840003): figure why this decode output has big numerical diff.
144
- skip_output_check = True
145
- if not skip_output_check:
146
- self.assertTrue(
147
- model_coverage.compare_tflite_torch(
148
- edge_model,
149
- copied_model,
150
- (decode_token, decode_input_pos),
151
- signature_name="decode",
152
- num_valid_inputs=1,
153
- )
154
- )
144
+ self.assertTrue(
145
+ model_coverage.compare_tflite_torch(
146
+ copied_edge,
147
+ copied_model,
148
+ (decode_token, decode_input_pos),
149
+ signature_name="decode",
150
+ num_valid_inputs=1,
151
+ )
152
+ )
155
153
 
156
154
 
157
155
  if __name__ == "__main__":
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
82
82
  model.eval()
83
83
 
84
84
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
85
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
86
- tokens[0, :4] = idx
87
- input_pos = torch.arange(0, 10)
85
+ prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
86
+ prefill_tokens[0, :4] = idx
87
+ prefill_input_pos = torch.arange(0, 10)
88
88
 
89
- edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
89
+ edge_model = ai_edge_torch.signature(
90
+ "prefill", model, (prefill_tokens, prefill_input_pos)
91
+ ).convert()
90
92
  edge_model.set_interpreter_builder(
91
93
  self._interpreter_builder(edge_model.tflite_model())
92
94
  )
93
95
 
94
- # TODO(b/362840003): debug numerical diff.
95
- skip_output_check = True
96
- if not skip_output_check:
97
- self.assertTrue(
98
- model_coverage.compare_tflite_torch(
99
- edge_model,
100
- model,
101
- (tokens, input_pos),
102
- num_valid_inputs=1,
103
- atol=1e-2,
104
- rtol=1e-5,
105
- )
106
- )
96
+ self.assertTrue(
97
+ model_coverage.compare_tflite_torch(
98
+ edge_model,
99
+ model,
100
+ (prefill_tokens, prefill_input_pos),
101
+ signature_name="prefill",
102
+ num_valid_inputs=1,
103
+ atol=1e-2,
104
+ rtol=1e-5,
105
+ )
106
+ )
107
107
 
108
108
  @googletest.skipIf(
109
109
  ai_edge_config.Config.use_torch_xla,
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
227
227
  exported_program: torch.export.ExportedProgram,
228
228
  ) -> MlirLowered:
229
229
  """Lower the exported program to MLIR."""
230
- if torch.__version__ >= "2.2":
231
- # torch version 2.1 didn't expose this yet
232
- exported_program = exported_program.run_decompositions()
233
- exported_program = exported_program.run_decompositions(
234
- lowerings.decompositions()
235
- )
230
+ exported_program = exported_program.run_decompositions(
231
+ lowerings.decompositions()
232
+ )
236
233
 
237
234
  with export_utils.create_ir_context() as context, ir.Location.unknown():
238
235
 
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
35
35
 
36
36
  def _lower_to_ir_text(
37
37
  jaxfn, args, kwargs, ir_input_names: list[str] = None
38
- ) -> str:
38
+ ) -> tuple[str, list[ir.Value]]:
39
39
  args = utils.tree_map_list_to_tuple(args)
40
40
  kwargs = utils.tree_map_list_to_tuple(kwargs)
41
41
 
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
74
74
  x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
75
75
  ]
76
76
 
77
- def new_lowering(*args, **jax_lower_static_kwargs):
77
+ def lower_wrapper(*args):
78
+ nonlocal jax_lower_static_kwargs
79
+
78
80
  jaxfn_args = []
79
81
  jaxfn_kwargs = jax_lower_static_kwargs.copy()
80
82
  for name, arg in zip(jax_lower_argnames, args):
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
85
87
 
86
88
  return jaxfn(*jaxfn_args, **jaxfn_kwargs)
87
89
 
88
- return (
89
- jax.jit(new_lowering, static_argnames=static_argnames)
90
- .lower(*jax_lower_args, **jax_lower_static_kwargs)
91
- .as_text()
92
- ), ir_inputs
90
+ return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
93
91
 
94
92
 
95
93
  def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
@@ -52,6 +52,7 @@ class LoweringRegistry:
52
52
 
53
53
 
54
54
  global_registry = LoweringRegistry()
55
+ global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
55
56
  global_registry.decompositions.update(
56
57
  torch._decomp.get_decompositions([
57
58
  torch.ops.aten.upsample_nearest2d,
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
70
71
  ])
71
72
  )
72
73
 
74
+ torch._decomp.remove_decompositions(
75
+ global_registry.decompositions,
76
+ [
77
+ torch.ops.aten.roll,
78
+ ],
79
+ )
80
+
73
81
 
74
82
  def lookup(op):
75
83
  return global_registry.lookup(op)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240905"
16
+ __version__ = "0.3.0.dev20240906"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240905
3
+ Version: 0.3.0.dev20240906
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=-vQGdl2EaV-VpHRty3RwZzH0UVntVt1tmjhtKOIDscw,706
5
+ ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -109,8 +109,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
109
109
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
110
110
  ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
111
111
  ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
112
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=KZ0uCeOdKMKyW8jBE8aOjweZmws4mvz37u8zH4XayVU,5285
113
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
112
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
113
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
114
114
  ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
115
115
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
116
116
  ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
@@ -133,7 +133,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGP
133
133
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
134
134
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
135
135
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
136
- ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
136
+ ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
137
137
  ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
138
138
  ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
139
139
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -143,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
143
143
  ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
144
144
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
145
145
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
146
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=hXvhKtbH7lGytm6QZOKpTmaLJN3kfENBcSIKQ39ReXA,5478
146
+ ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
147
147
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
148
148
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
149
149
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
@@ -151,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
151
151
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
152
152
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
153
153
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
154
- ai_edge_torch/odml_torch/lowerings/registry.py,sha256=dcnxq8vV9rxSQqXkjSg9it7l6oP_sdfH8kIZdQNkQ_4,2653
154
+ ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
155
155
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
156
156
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
157
157
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
161
161
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
162
162
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
163
163
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
164
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/METADATA,sha256=8yrrm7TEYgaRhKdUwgStjCqrTWs8YcnnlzoTJt2NrJg,1859
166
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/RECORD,,
164
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
166
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,