ai-edge-torch-nightly 0.3.0.dev20250104__py3-none-any.whl → 0.3.0.dev20250106__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +9 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250104.dist-info → ai_edge_torch_nightly-0.3.0.dev20250106.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250104.dist-info → ai_edge_torch_nightly-0.3.0.dev20250106.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.3.0.dev20250104.dist-info → ai_edge_torch_nightly-0.3.0.dev20250106.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250104.dist-info → ai_edge_torch_nightly-0.3.0.dev20250106.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250104.dist-info → ai_edge_torch_nightly-0.3.0.dev20250106.dist-info}/top_level.txt +0 -0
@@ -155,6 +155,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
|
|
155
155
|
@layout_sensitive_inputs_getters.register(
|
156
156
|
aten._native_batch_norm_legit_no_training
|
157
157
|
)
|
158
|
+
@layout_sensitive_inputs_getters.register(aten.group_norm)
|
158
159
|
@layout_sensitive_inputs_getters.register(aten.native_group_norm)
|
159
160
|
def _first_arg_getter(node):
|
160
161
|
return [node.args[0]]
|
@@ -188,6 +189,14 @@ def _aten_norm_checker(node):
|
|
188
189
|
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
189
190
|
|
190
191
|
|
192
|
+
@nhwcable_node_checkers.register(aten.group_norm)
def _aten_group_norm_checker(node):
  """Decide whether an aten.group_norm node may be laid out as NHWC.

  Reads the value cached in ``node.meta["val"]`` (presumably the traced
  output tensor -- confirm against the tracer). The node is NHWC-able only
  when that value has a ``shape`` of rank 4; it is never *required* to be
  NHWC.

  Args:
    node: The FX node being checked.

  Returns:
    An ``NHWCable`` with ``can_be`` set to the rank-4 test and
    ``must_be=False``.
  """
  out_val = node.meta.get("val")
  # A missing or shapeless value short-circuits to "not NHWC-able".
  rank_is_four = hasattr(out_val, "shape") and len(out_val.shape) == 4
  return NHWCable(can_be=rank_is_four, must_be=False)
|
198
|
+
|
199
|
+
|
191
200
|
@nhwcable_node_checkers.register(aten.native_group_norm)
|
192
201
|
def _aten_native_group_norm_checker(node):
|
193
202
|
val = node.meta.get("val")
|
@@ -342,6 +342,18 @@ def _aten__native_batch_norm_legit_no_training(node):
|
|
342
342
|
node.target = batch_norm
|
343
343
|
|
344
344
|
|
345
|
+
@rewriters.register(aten.group_norm.default)
def _aten_group_norm(node):
  """Rewrite an aten.group_norm node so it can sit in the NHWC region.

  Replaces ``node.target`` with a wrapper with the following steps:
  - pass the input through ``utils.tensor_to_nchw``;
  - call the original ``aten.group_norm.default`` overload;
  - pass the result through ``utils.tensor_to_nhwc``.

  Args:
    node: The FX node whose target is ``aten.group_norm.default``; it is
      mutated in place.
  """

  def group_norm(
      input,
      num_groups: int,
      weight=None,
      bias=None,
      eps=1e-5,
      cudnn_enabled=True,
  ):
    # Disable NHWC rewriter with native decomposed ops due to precision issue.
    # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
    input = utils.tensor_to_nchw(input)
    # aten::group_norm's schema ends with `bool cudnn_enabled=True`. Accept and
    # forward it, so a node carrying all six arguments (or that kwarg) does not
    # raise TypeError when this replacement target is invoked.
    res = aten.group_norm.default(
        input,
        num_groups,
        weight,
        bias,
        eps=eps,
        cudnn_enabled=cudnn_enabled,
    )
    return utils.tensor_to_nhwc(res)

  node.target = group_norm
|
355
|
+
|
356
|
+
|
345
357
|
@rewriters.register(aten.native_group_norm.default)
|
346
358
|
def _aten_native_group_norm(node):
|
347
359
|
|
@@ -354,6 +366,7 @@ def _aten_native_group_norm(node):
|
|
354
366
|
flattened_inner_size: int,
|
355
367
|
num_groups: int,
|
356
368
|
eps: float,
|
369
|
+
**kwargs,
|
357
370
|
):
|
358
371
|
input_reshaped = torch.reshape(
|
359
372
|
input,
|
ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20250104
|
3
|
+
Version: 0.3.0.dev20250106
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=l4ka-RcLl8uhiXMC9pfER4_jVWj1v3NJPPQZT4f7uZs,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4J
|
|
16
16
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
|
17
17
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
19
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
19
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=aXv8hvxHWOr5xiZIczAWdxPBV_3nEJVzJgMeNor55ps,7947
|
20
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
21
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
|
21
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=fe97b27EwWNnXxnphM6LY5CNNveJmlYNnAFND9X124Y,13239
|
22
22
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
|
23
23
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
|
24
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
203
203
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
204
204
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
205
205
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
206
|
-
ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
209
|
-
ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
210
|
-
ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/RECORD,,
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/METADATA,sha256=GGqVNvcJLsaksphF-aI8acthK7eYJQ0vKg2BktlXDRQ,1966
|
208
|
+
ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
209
|
+
ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
210
|
+
ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/RECORD,,
|
File without changes
|
File without changes
|