ai-edge-torch-nightly 0.3.0.dev20240905__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/test/test_model_conversion.py +10 -12
- ai_edge_torch/generative/test/test_model_conversion_large.py +17 -17
- ai_edge_torch/odml_torch/export.py +3 -6
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +5 -7
- ai_edge_torch/odml_torch/lowerings/registry.py +8 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/top_level.txt +0 -0
@@ -129,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
|
|
129
129
|
)
|
130
130
|
|
131
131
|
copied_model = copy.deepcopy(pytorch_model)
|
132
|
+
copied_edge = copy.deepcopy(edge_model)
|
132
133
|
|
133
134
|
self.assertTrue(
|
134
135
|
model_coverage.compare_tflite_torch(
|
@@ -140,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
|
|
140
141
|
)
|
141
142
|
)
|
142
143
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
num_valid_inputs=1,
|
153
|
-
)
|
154
|
-
)
|
144
|
+
self.assertTrue(
|
145
|
+
model_coverage.compare_tflite_torch(
|
146
|
+
copied_edge,
|
147
|
+
copied_model,
|
148
|
+
(decode_token, decode_input_pos),
|
149
|
+
signature_name="decode",
|
150
|
+
num_valid_inputs=1,
|
151
|
+
)
|
152
|
+
)
|
155
153
|
|
156
154
|
|
157
155
|
if __name__ == "__main__":
|
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
|
|
82
82
|
model.eval()
|
83
83
|
|
84
84
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
85
|
-
|
86
|
-
|
87
|
-
|
85
|
+
prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
86
|
+
prefill_tokens[0, :4] = idx
|
87
|
+
prefill_input_pos = torch.arange(0, 10)
|
88
88
|
|
89
|
-
edge_model = ai_edge_torch.
|
89
|
+
edge_model = ai_edge_torch.signature(
|
90
|
+
"prefill", model, (prefill_tokens, prefill_input_pos)
|
91
|
+
).convert()
|
90
92
|
edge_model.set_interpreter_builder(
|
91
93
|
self._interpreter_builder(edge_model.tflite_model())
|
92
94
|
)
|
93
95
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
)
|
106
|
-
)
|
96
|
+
self.assertTrue(
|
97
|
+
model_coverage.compare_tflite_torch(
|
98
|
+
edge_model,
|
99
|
+
model,
|
100
|
+
(prefill_tokens, prefill_input_pos),
|
101
|
+
signature_name="prefill",
|
102
|
+
num_valid_inputs=1,
|
103
|
+
atol=1e-2,
|
104
|
+
rtol=1e-5,
|
105
|
+
)
|
106
|
+
)
|
107
107
|
|
108
108
|
@googletest.skipIf(
|
109
109
|
ai_edge_config.Config.use_torch_xla,
|
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
|
|
227
227
|
exported_program: torch.export.ExportedProgram,
|
228
228
|
) -> MlirLowered:
|
229
229
|
"""Lower the exported program to MLIR."""
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
exported_program = exported_program.run_decompositions(
|
234
|
-
lowerings.decompositions()
|
235
|
-
)
|
230
|
+
exported_program = exported_program.run_decompositions(
|
231
|
+
lowerings.decompositions()
|
232
|
+
)
|
236
233
|
|
237
234
|
with export_utils.create_ir_context() as context, ir.Location.unknown():
|
238
235
|
|
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
|
|
35
35
|
|
36
36
|
def _lower_to_ir_text(
|
37
37
|
jaxfn, args, kwargs, ir_input_names: list[str] = None
|
38
|
-
) -> str:
|
38
|
+
) -> tuple[str, list[ir.Value]]:
|
39
39
|
args = utils.tree_map_list_to_tuple(args)
|
40
40
|
kwargs = utils.tree_map_list_to_tuple(kwargs)
|
41
41
|
|
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
|
|
74
74
|
x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
|
75
75
|
]
|
76
76
|
|
77
|
-
def
|
77
|
+
def lower_wrapper(*args):
|
78
|
+
nonlocal jax_lower_static_kwargs
|
79
|
+
|
78
80
|
jaxfn_args = []
|
79
81
|
jaxfn_kwargs = jax_lower_static_kwargs.copy()
|
80
82
|
for name, arg in zip(jax_lower_argnames, args):
|
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
|
|
85
87
|
|
86
88
|
return jaxfn(*jaxfn_args, **jaxfn_kwargs)
|
87
89
|
|
88
|
-
return (
|
89
|
-
jax.jit(new_lowering, static_argnames=static_argnames)
|
90
|
-
.lower(*jax_lower_args, **jax_lower_static_kwargs)
|
91
|
-
.as_text()
|
92
|
-
), ir_inputs
|
90
|
+
return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
|
93
91
|
|
94
92
|
|
95
93
|
def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
|
@@ -52,6 +52,7 @@ class LoweringRegistry:
|
|
52
52
|
|
53
53
|
|
54
54
|
global_registry = LoweringRegistry()
|
55
|
+
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
|
55
56
|
global_registry.decompositions.update(
|
56
57
|
torch._decomp.get_decompositions([
|
57
58
|
torch.ops.aten.upsample_nearest2d,
|
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
|
|
70
71
|
])
|
71
72
|
)
|
72
73
|
|
74
|
+
torch._decomp.remove_decompositions(
|
75
|
+
global_registry.decompositions,
|
76
|
+
[
|
77
|
+
torch.ops.aten.roll,
|
78
|
+
],
|
79
|
+
)
|
80
|
+
|
73
81
|
|
74
82
|
def lookup(op):
|
75
83
|
return global_registry.lookup(op)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20240905
|
3
|
+
Version: 0.3.0.dev20240906
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
5
|
-
ai_edge_torch/version.py,sha256
|
5
|
+
ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -109,8 +109,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
|
|
109
109
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
110
110
|
ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
|
111
111
|
ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
|
112
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
113
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
112
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
|
113
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
|
114
114
|
ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
|
115
115
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
116
116
|
ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
|
@@ -133,7 +133,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGP
|
|
133
133
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
134
134
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
135
135
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
136
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
136
|
+
ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
|
137
137
|
ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
|
138
138
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
|
139
139
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -143,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
|
|
143
143
|
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
|
144
144
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
145
145
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
|
146
|
-
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
|
146
|
+
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
147
147
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
148
148
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
|
149
149
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
|
@@ -151,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
|
|
151
151
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
|
152
152
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
|
153
153
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
154
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
154
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
|
155
155
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
156
156
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
157
157
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
161
161
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
162
162
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
163
163
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
164
|
-
ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
165
|
-
ai_edge_torch_nightly-0.3.0.
|
166
|
-
ai_edge_torch_nightly-0.3.0.
|
167
|
-
ai_edge_torch_nightly-0.3.0.
|
168
|
-
ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/RECORD,,
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
|
File without changes
|
File without changes
|