ai-edge-torch-nightly 0.3.0.dev20240905__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -129,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
129
129
  )
130
130
 
131
131
  copied_model = copy.deepcopy(pytorch_model)
132
+ copied_edge = copy.deepcopy(edge_model)
132
133
 
133
134
  self.assertTrue(
134
135
  model_coverage.compare_tflite_torch(
@@ -140,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
140
141
  )
141
142
  )
142
143
 
143
- # TODO(b/362840003): figure why this decode output has big numerical diff.
144
- skip_output_check = True
145
- if not skip_output_check:
146
- self.assertTrue(
147
- model_coverage.compare_tflite_torch(
148
- edge_model,
149
- copied_model,
150
- (decode_token, decode_input_pos),
151
- signature_name="decode",
152
- num_valid_inputs=1,
153
- )
154
- )
144
+ self.assertTrue(
145
+ model_coverage.compare_tflite_torch(
146
+ copied_edge,
147
+ copied_model,
148
+ (decode_token, decode_input_pos),
149
+ signature_name="decode",
150
+ num_valid_inputs=1,
151
+ )
152
+ )
155
153
 
156
154
 
157
155
  if __name__ == "__main__":
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
82
82
  model.eval()
83
83
 
84
84
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
85
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
86
- tokens[0, :4] = idx
87
- input_pos = torch.arange(0, 10)
85
+ prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
86
+ prefill_tokens[0, :4] = idx
87
+ prefill_input_pos = torch.arange(0, 10)
88
88
 
89
- edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
89
+ edge_model = ai_edge_torch.signature(
90
+ "prefill", model, (prefill_tokens, prefill_input_pos)
91
+ ).convert()
90
92
  edge_model.set_interpreter_builder(
91
93
  self._interpreter_builder(edge_model.tflite_model())
92
94
  )
93
95
 
94
- # TODO(b/362840003): debug numerical diff.
95
- skip_output_check = True
96
- if not skip_output_check:
97
- self.assertTrue(
98
- model_coverage.compare_tflite_torch(
99
- edge_model,
100
- model,
101
- (tokens, input_pos),
102
- num_valid_inputs=1,
103
- atol=1e-2,
104
- rtol=1e-5,
105
- )
106
- )
96
+ self.assertTrue(
97
+ model_coverage.compare_tflite_torch(
98
+ edge_model,
99
+ model,
100
+ (prefill_tokens, prefill_input_pos),
101
+ signature_name="prefill",
102
+ num_valid_inputs=1,
103
+ atol=1e-2,
104
+ rtol=1e-5,
105
+ )
106
+ )
107
107
 
108
108
  @googletest.skipIf(
109
109
  ai_edge_config.Config.use_torch_xla,
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
227
227
  exported_program: torch.export.ExportedProgram,
228
228
  ) -> MlirLowered:
229
229
  """Lower the exported program to MLIR."""
230
- if torch.__version__ >= "2.2":
231
- # torch version 2.1 didn't expose this yet
232
- exported_program = exported_program.run_decompositions()
233
- exported_program = exported_program.run_decompositions(
234
- lowerings.decompositions()
235
- )
230
+ exported_program = exported_program.run_decompositions(
231
+ lowerings.decompositions()
232
+ )
236
233
 
237
234
  with export_utils.create_ir_context() as context, ir.Location.unknown():
238
235
 
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
35
35
 
36
36
  def _lower_to_ir_text(
37
37
  jaxfn, args, kwargs, ir_input_names: list[str] = None
38
- ) -> str:
38
+ ) -> tuple[str, list[ir.Value]]:
39
39
  args = utils.tree_map_list_to_tuple(args)
40
40
  kwargs = utils.tree_map_list_to_tuple(kwargs)
41
41
 
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
74
74
  x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
75
75
  ]
76
76
 
77
- def new_lowering(*args, **jax_lower_static_kwargs):
77
+ def lower_wrapper(*args):
78
+ nonlocal jax_lower_static_kwargs
79
+
78
80
  jaxfn_args = []
79
81
  jaxfn_kwargs = jax_lower_static_kwargs.copy()
80
82
  for name, arg in zip(jax_lower_argnames, args):
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
85
87
 
86
88
  return jaxfn(*jaxfn_args, **jaxfn_kwargs)
87
89
 
88
- return (
89
- jax.jit(new_lowering, static_argnames=static_argnames)
90
- .lower(*jax_lower_args, **jax_lower_static_kwargs)
91
- .as_text()
92
- ), ir_inputs
90
+ return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
93
91
 
94
92
 
95
93
  def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
@@ -52,6 +52,7 @@ class LoweringRegistry:
52
52
 
53
53
 
54
54
  global_registry = LoweringRegistry()
55
+ global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
55
56
  global_registry.decompositions.update(
56
57
  torch._decomp.get_decompositions([
57
58
  torch.ops.aten.upsample_nearest2d,
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
70
71
  ])
71
72
  )
72
73
 
74
+ torch._decomp.remove_decompositions(
75
+ global_registry.decompositions,
76
+ [
77
+ torch.ops.aten.roll,
78
+ ],
79
+ )
80
+
73
81
 
74
82
  def lookup(op):
75
83
  return global_registry.lookup(op)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240905"
16
+ __version__ = "0.3.0.dev20240906"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240905
3
+ Version: 0.3.0.dev20240906
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=-vQGdl2EaV-VpHRty3RwZzH0UVntVt1tmjhtKOIDscw,706
5
+ ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -109,8 +109,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
109
109
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
110
110
  ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
111
111
  ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
112
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=KZ0uCeOdKMKyW8jBE8aOjweZmws4mvz37u8zH4XayVU,5285
113
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
112
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
113
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
114
114
  ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
115
115
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
116
116
  ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
@@ -133,7 +133,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGP
133
133
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
134
134
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
135
135
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
136
- ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
136
+ ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
137
137
  ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
138
138
  ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
139
139
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -143,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
143
143
  ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
144
144
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
145
145
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
146
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=hXvhKtbH7lGytm6QZOKpTmaLJN3kfENBcSIKQ39ReXA,5478
146
+ ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
147
147
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
148
148
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
149
149
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
@@ -151,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
151
151
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
152
152
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
153
153
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
154
- ai_edge_torch/odml_torch/lowerings/registry.py,sha256=dcnxq8vV9rxSQqXkjSg9it7l6oP_sdfH8kIZdQNkQ_4,2653
154
+ ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
155
155
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
156
156
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
157
157
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
161
161
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
162
162
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
163
163
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
164
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/METADATA,sha256=8yrrm7TEYgaRhKdUwgStjCqrTWs8YcnnlzoTJt2NrJg,1859
166
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
- ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/RECORD,,
164
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
166
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
+ ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,