ai-edge-torch-nightly 0.3.0.dev20250211__py3-none-any.whl → 0.3.0.dev20250212__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +4 -6
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +1 -3
- ai_edge_torch/generative/utilities/converter.py +1 -1
- ai_edge_torch/odml_torch/debuginfo/__init__.py +7 -1
- ai_edge_torch/odml_torch/debuginfo/_build.py +22 -2
- ai_edge_torch/odml_torch/export.py +21 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250211.dist-info → ai_edge_torch_nightly-0.3.0.dev20250212.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250211.dist-info → ai_edge_torch_nightly-0.3.0.dev20250212.dist-info}/RECORD +12 -12
- {ai_edge_torch_nightly-0.3.0.dev20250211.dist-info → ai_edge_torch_nightly-0.3.0.dev20250212.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250211.dist-info → ai_edge_torch_nightly-0.3.0.dev20250212.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250211.dist-info → ai_edge_torch_nightly-0.3.0.dev20250212.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -56,11 +56,6 @@ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
-_PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
-    'pixel_values_size',
-    [3, 224, 224],
-    'The size of prefill pixel values except the batch dimension.',
-)
 _QUANTIZE = flags.DEFINE_bool(
     'quantize',
     True,
@@ -75,12 +70,15 @@ def main(_):
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
 
+  config = pytorch_model.image_encoder.config.image_embedding
   converter.convert_to_tflite(
       pytorch_model,
       output_path=_OUTPUT_PATH.value,
       output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
-      pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
+      pixel_values_size=torch.Size(
+          [1, config.channels, config.image_size, config.image_size]
+      ),
       quantize=_QUANTIZE.value,
       config=pytorch_model.config.decoder_config,
       export_config=ExportConfig(),
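The converter no longer reads the deleted --pixel_values_size flag; the prefill pixel shape is now derived from the image encoder's own image_embedding config. A minimal sketch of the new shape computation, using a stand-in dataclass (an assumption for illustration) in place of the real image_embedding config:

```python
# Sketch only: _ImageEmbeddingStandIn mimics the fields the converter reads
# from pytorch_model.image_encoder.config.image_embedding.
import dataclasses

import torch


@dataclasses.dataclass
class _ImageEmbeddingStandIn:
  channels: int = 3
  image_size: int = 224


config = _ImageEmbeddingStandIn()
pixel_values_size = torch.Size(
    [1, config.channels, config.image_size, config.image_size]
)
print(pixel_values_size)  # torch.Size([1, 3, 224, 224])
```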
ai_edge_torch/generative/examples/paligemma/image_encoder.py CHANGED
@@ -136,9 +136,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-
-      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
-      # enable_hlfb=True,
+      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -145,7 +145,7 @@ def _export_helper(
     prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
 
   prefill_pixel_values = (
-      torch.full(
+      torch.full(pixel_values_size, 0, dtype=torch.float32)
       if pixel_values_size
       else None
   )
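For context, torch.full accepts a torch.Size (or any sequence of ints) as its size argument, so the updated helper builds the zero-filled prefill pixel tensor in a single call. A minimal sketch, assuming a batch-1 RGB 224x224 input:

```python
import torch

# Assumed shape for illustration; in the converter it is passed in by the caller.
pixel_values_size = torch.Size([1, 3, 224, 224])

prefill_pixel_values = (
    torch.full(pixel_values_size, 0, dtype=torch.float32)
    if pixel_values_size
    else None
)
print(prefill_pixel_values.shape)  # torch.Size([1, 3, 224, 224])
```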
ai_edge_torch/odml_torch/debuginfo/__init__.py CHANGED
@@ -12,5 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+"""Debug info generation for ODML Torch."""
+
+from . import _build
 from ._op_polyfill import write_mlir_debuginfo_op
+
+build_nodename_debuginfo = _build.build_nodename_debuginfo
+build_mlir_file_debuginfo = _build.build_mlir_file_debuginfo
+build_mlir_debuginfo = _build.build_mlir_debuginfo
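With these module-level aliases in place, the debug-info helpers can be imported from the package itself rather than from the private _build module. Usage sketch, assuming the 0.3.0.dev20250212 nightly is installed:

```python
# The aliases added above make these names importable from the package.
from ai_edge_torch.odml_torch.debuginfo import (
    build_mlir_debuginfo,
    build_mlir_file_debuginfo,
    build_nodename_debuginfo,
)
```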
ai_edge_torch/odml_torch/debuginfo/_build.py CHANGED
@@ -40,10 +40,16 @@ def _get_canonical_filename(filename):
 
   This should be factored out so that pattern is a global option that a user
   can override.
+
+  Args:
+    filename: The filename to canonicalize.
+
+  Returns:
+    The canonicalized filename.
   """
 
-  # TODO: We should add a config option to provide a regex to strip
-  # debug info. Currently absolute path is used.
+  # TODO(yijieyang): We should add a config option to provide a regex to strip
+  # from the debug info. Currently absolute path is used.
   return filename
 
 
@@ -55,9 +61,23 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
 
   # Note: This uses internal APIs and may break in the future.
   pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
+  if pt_trace is None:
+    return None, None
   return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
 
 
+def build_nodename_debuginfo(node: torch.fx.Node):
+  """Build the fx node name for the given node's lowerings in MLIR."""
+  history = node.meta.get("from_node", [])
+  if not history:
+    return None
+  if len(history) > 1:
+    return history[1][0]
+  if hasattr(history[0], "name"):  # torch 2.6.0+
+    return history[0].name
+  return None
+
+
 def build_mlir_debuginfo(node: torch.fx.Node):
   """Build the debuginfo string for the given node's lowerings in MLIR."""
 
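build_nodename_debuginfo reads the provenance that export passes record in node.meta["from_node"]; per the inline comment the entry format changed in torch 2.6.0, so the function handles both an older tuple-style history and newer entries exposing a .name attribute. A standalone sketch of that lookup order (the sample meta contents below are made up for illustration, not values produced by a real export):

```python
from types import SimpleNamespace


def _nodename_from_meta(meta):
  """Mirrors the lookup order of build_nodename_debuginfo (sketch only)."""
  history = meta.get("from_node", [])
  if not history:
    return None
  if len(history) > 1:  # older format: entries indexable like (name, ...)
    return history[1][0]
  if hasattr(history[0], "name"):  # torch 2.6.0+: entry carries a .name
    return history[0].name
  return None


# Hypothetical examples of the two formats.
new_style = {"from_node": [SimpleNamespace(name="add_1")]}
old_style = {"from_node": [("add", 0), ("add_1", 0)]}

print(_nodename_from_meta(new_style))  # add_1
print(_nodename_from_meta(old_style))  # add_1
```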
ai_edge_torch/odml_torch/export.py CHANGED
@@ -88,6 +88,21 @@ class LoweringInterpreter(torch.fx.Interpreter):
     self.outputs = None
 
   def _build_loc(self, node: torch.fx.Node):
+    """Build MLIR location for the given node.
+
+    The location contains:
+    - layer info
+    - fx node name
+    - file and line info
+
+    Currently it's still under development and format is subject to change.
+
+    Args:
+      node: The torch.fx.Node to build the location for.
+
+    Returns:
+      The MLIR location for the given node.
+    """
 
     info = debuginfo.build_mlir_debuginfo(node)
     if info is None:
@@ -98,7 +113,12 @@ class LoweringInterpreter(torch.fx.Interpreter):
     if file is not None:
       fileinfo = ir.Location.file(filename=file, line=line, col=0)
 
-    return ir.Location.name(name=info, childLoc=fileinfo)
+    node_name = debuginfo.build_nodename_debuginfo(node)
+    nodeinfo = None
+    if node_name is not None:
+      nodeinfo = ir.Location.name(name=node_name, childLoc=fileinfo)
+
+    return ir.Location.name(name=info, childLoc=nodeinfo)
 
   def run_node(self, node: torch.fx.Node):
     loc = self._build_loc(node)
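The lowering now nests three layers of provenance into one MLIR location: a name location carrying the layer info, wrapping a name location carrying the fx node name, wrapping a file location. A rough sketch of that structure using the MLIR Python bindings; the jax-bundled import path and the literal strings are assumptions for illustration:

```python
# Assumes the MLIR Python bindings bundled with jax are available; odml_torch's
# actual `ir` import may differ per install.
from jax._src.lib.mlir import ir

with ir.Context():
  fileinfo = ir.Location.file(filename="model.py", line=42, col=0)
  nodeinfo = ir.Location.name(name="add_1", childLoc=fileinfo)  # fx node name
  loc = ir.Location.name(name="SampleModel/layer1;", childLoc=nodeinfo)  # layer info
  print(loc)
```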
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250211.dist-info → ai_edge_torch_nightly-0.3.0.dev20250212.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250211
+Version: 0.3.0.dev20250212
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250211.dist-info → ai_edge_torch_nightly-0.3.0.dev20250212.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=gYCO29a18zxFgm6z4IHCArhNE29IvlL_sxhEL6O4ECw,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,10 +73,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=APQymtr3n2k6-e8wvn3kVrli0qiElduYIkHeahcoSA0,2743
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
 ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
@@ -172,7 +172,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728Fc
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=K9taR0KY59dvfU_jO1yBe_p7w8lDns1Q3U6oJTTKZzM,8058
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
@@ -197,14 +197,14 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=7l8R0DEq_vfns8iWpruMlIyaIKZAFzoAy369-7iRrl0,14164
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
-ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=
-ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=
+ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
+ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=JIMCn_fNh5-PgcV5qcklD7aFj0RhNKlvnZ-XQFCOszc,2706
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -229,8 +229,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/METADATA,sha256=uCdhC0TvTD1j2cNaIsI1l_uB_FVHMml_jDr2vUzzKHw,1966
+ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250212.dist-info/RECORD,,