ai-edge-torch-nightly 0.3.0.dev20240905__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/generative/test/test_model_conversion.py +10 -12
- ai_edge_torch/generative/test/test_model_conversion_large.py +17 -17
- ai_edge_torch/odml_torch/export.py +3 -6
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +5 -7
- ai_edge_torch/odml_torch/lowerings/registry.py +8 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240905.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/top_level.txt +0 -0
|
@@ -129,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
|
|
|
129
129
|
)
|
|
130
130
|
|
|
131
131
|
copied_model = copy.deepcopy(pytorch_model)
|
|
132
|
+
copied_edge = copy.deepcopy(edge_model)
|
|
132
133
|
|
|
133
134
|
self.assertTrue(
|
|
134
135
|
model_coverage.compare_tflite_torch(
|
|
@@ -140,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
|
|
|
140
141
|
)
|
|
141
142
|
)
|
|
142
143
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
num_valid_inputs=1,
|
|
153
|
-
)
|
|
154
|
-
)
|
|
144
|
+
self.assertTrue(
|
|
145
|
+
model_coverage.compare_tflite_torch(
|
|
146
|
+
copied_edge,
|
|
147
|
+
copied_model,
|
|
148
|
+
(decode_token, decode_input_pos),
|
|
149
|
+
signature_name="decode",
|
|
150
|
+
num_valid_inputs=1,
|
|
151
|
+
)
|
|
152
|
+
)
|
|
155
153
|
|
|
156
154
|
|
|
157
155
|
if __name__ == "__main__":
|
|
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
|
|
|
82
82
|
model.eval()
|
|
83
83
|
|
|
84
84
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
85
|
+
prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
|
86
|
+
prefill_tokens[0, :4] = idx
|
|
87
|
+
prefill_input_pos = torch.arange(0, 10)
|
|
88
88
|
|
|
89
|
-
edge_model = ai_edge_torch.
|
|
89
|
+
edge_model = ai_edge_torch.signature(
|
|
90
|
+
"prefill", model, (prefill_tokens, prefill_input_pos)
|
|
91
|
+
).convert()
|
|
90
92
|
edge_model.set_interpreter_builder(
|
|
91
93
|
self._interpreter_builder(edge_model.tflite_model())
|
|
92
94
|
)
|
|
93
95
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
)
|
|
106
|
-
)
|
|
96
|
+
self.assertTrue(
|
|
97
|
+
model_coverage.compare_tflite_torch(
|
|
98
|
+
edge_model,
|
|
99
|
+
model,
|
|
100
|
+
(prefill_tokens, prefill_input_pos),
|
|
101
|
+
signature_name="prefill",
|
|
102
|
+
num_valid_inputs=1,
|
|
103
|
+
atol=1e-2,
|
|
104
|
+
rtol=1e-5,
|
|
105
|
+
)
|
|
106
|
+
)
|
|
107
107
|
|
|
108
108
|
@googletest.skipIf(
|
|
109
109
|
ai_edge_config.Config.use_torch_xla,
|
|
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
|
|
|
227
227
|
exported_program: torch.export.ExportedProgram,
|
|
228
228
|
) -> MlirLowered:
|
|
229
229
|
"""Lower the exported program to MLIR."""
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
exported_program = exported_program.run_decompositions(
|
|
234
|
-
lowerings.decompositions()
|
|
235
|
-
)
|
|
230
|
+
exported_program = exported_program.run_decompositions(
|
|
231
|
+
lowerings.decompositions()
|
|
232
|
+
)
|
|
236
233
|
|
|
237
234
|
with export_utils.create_ir_context() as context, ir.Location.unknown():
|
|
238
235
|
|
|
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
|
|
|
35
35
|
|
|
36
36
|
def _lower_to_ir_text(
|
|
37
37
|
jaxfn, args, kwargs, ir_input_names: list[str] = None
|
|
38
|
-
) -> str:
|
|
38
|
+
) -> tuple[str, list[ir.Value]]:
|
|
39
39
|
args = utils.tree_map_list_to_tuple(args)
|
|
40
40
|
kwargs = utils.tree_map_list_to_tuple(kwargs)
|
|
41
41
|
|
|
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
|
|
|
74
74
|
x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
|
|
75
75
|
]
|
|
76
76
|
|
|
77
|
-
def
|
|
77
|
+
def lower_wrapper(*args):
|
|
78
|
+
nonlocal jax_lower_static_kwargs
|
|
79
|
+
|
|
78
80
|
jaxfn_args = []
|
|
79
81
|
jaxfn_kwargs = jax_lower_static_kwargs.copy()
|
|
80
82
|
for name, arg in zip(jax_lower_argnames, args):
|
|
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
|
|
|
85
87
|
|
|
86
88
|
return jaxfn(*jaxfn_args, **jaxfn_kwargs)
|
|
87
89
|
|
|
88
|
-
return (
|
|
89
|
-
jax.jit(new_lowering, static_argnames=static_argnames)
|
|
90
|
-
.lower(*jax_lower_args, **jax_lower_static_kwargs)
|
|
91
|
-
.as_text()
|
|
92
|
-
), ir_inputs
|
|
90
|
+
return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
|
|
93
91
|
|
|
94
92
|
|
|
95
93
|
def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
|
|
@@ -52,6 +52,7 @@ class LoweringRegistry:
|
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
global_registry = LoweringRegistry()
|
|
55
|
+
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
|
|
55
56
|
global_registry.decompositions.update(
|
|
56
57
|
torch._decomp.get_decompositions([
|
|
57
58
|
torch.ops.aten.upsample_nearest2d,
|
|
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
|
|
|
70
71
|
])
|
|
71
72
|
)
|
|
72
73
|
|
|
74
|
+
torch._decomp.remove_decompositions(
|
|
75
|
+
global_registry.decompositions,
|
|
76
|
+
[
|
|
77
|
+
torch.ops.aten.roll,
|
|
78
|
+
],
|
|
79
|
+
)
|
|
80
|
+
|
|
73
81
|
|
|
74
82
|
def lookup(op):
|
|
75
83
|
return global_registry.lookup(op)
|
ai_edge_torch/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.3.0.dev20240905
|
|
3
|
+
Version: 0.3.0.dev20240906
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
|
4
4
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
|
5
|
-
ai_edge_torch/version.py,sha256
|
|
5
|
+
ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
|
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
|
@@ -109,8 +109,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
|
|
|
109
109
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
110
110
|
ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
|
|
111
111
|
ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
|
|
112
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
|
113
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
|
112
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
|
|
113
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
|
|
114
114
|
ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
|
|
115
115
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
116
116
|
ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
|
|
@@ -133,7 +133,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGP
|
|
|
133
133
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
|
134
134
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
|
135
135
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
|
136
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
|
136
|
+
ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
|
|
137
137
|
ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
|
|
138
138
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
|
|
139
139
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
|
@@ -143,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
|
|
|
143
143
|
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
|
|
144
144
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
|
145
145
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
|
|
146
|
-
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
|
|
146
|
+
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
|
147
147
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
|
148
148
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
|
|
149
149
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
|
|
@@ -151,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
|
|
|
151
151
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
|
|
152
152
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
|
|
153
153
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
|
154
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
|
154
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
|
|
155
155
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
|
156
156
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
|
157
157
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
|
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
|
161
161
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
162
162
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
163
163
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
|
165
|
-
ai_edge_torch_nightly-0.3.0.
|
|
166
|
-
ai_edge_torch_nightly-0.3.0.
|
|
167
|
-
ai_edge_torch_nightly-0.3.0.
|
|
168
|
-
ai_edge_torch_nightly-0.3.0.
|
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
|
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|