ai-edge-torch-nightly 0.3.0.dev20250104__py3-none-any.whl → 0.3.0.dev20250106__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -155,6 +155,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
155
155
  @layout_sensitive_inputs_getters.register(
156
156
  aten._native_batch_norm_legit_no_training
157
157
  )
158
+ @layout_sensitive_inputs_getters.register(aten.group_norm)
158
159
  @layout_sensitive_inputs_getters.register(aten.native_group_norm)
159
160
  def _first_arg_getter(node):
160
161
  return [node.args[0]]
@@ -188,6 +189,14 @@ def _aten_norm_checker(node):
188
189
  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
189
190
 
190
191
 
192
+ @nhwcable_node_checkers.register(aten.group_norm)
193
+ def _aten_group_norm_checker(node):
194
+ val = node.meta.get("val")
195
+ if not hasattr(val, "shape"):
196
+ return NHWCable(can_be=False, must_be=False)
197
+ return NHWCable(can_be=len(val.shape) == 4, must_be=False)
198
+
199
+
191
200
  @nhwcable_node_checkers.register(aten.native_group_norm)
192
201
  def _aten_native_group_norm_checker(node):
193
202
  val = node.meta.get("val")
@@ -342,6 +342,18 @@ def _aten__native_batch_norm_legit_no_training(node):
342
342
  node.target = batch_norm
343
343
 
344
344
 
345
+ @rewriters.register(aten.group_norm.default)
346
+ def _aten_group_norm(node):
347
+ def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
348
+ # Disable NHWC rewriter with native decomposied ops due to precision issue.
349
+ # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
350
+ input = utils.tensor_to_nchw(input)
351
+ res = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
352
+ return utils.tensor_to_nhwc(res)
353
+
354
+ node.target = group_norm
355
+
356
+
345
357
  @rewriters.register(aten.native_group_norm.default)
346
358
  def _aten_native_group_norm(node):
347
359
 
@@ -354,6 +366,7 @@ def _aten_native_group_norm(node):
354
366
  flattened_inner_size: int,
355
367
  num_groups: int,
356
368
  eps: float,
369
+ **kwargs,
357
370
  ):
358
371
  input_reshaped = torch.reshape(
359
372
  input,
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250104"
16
+ __version__ = "0.3.0.dev20250106"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250104
3
+ Version: 0.3.0.dev20250106
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=ogz9lERVLAHGNMBjJhihuAL8IMBQXbVb1X1FSVdQVcY,706
6
+ ai_edge_torch/version.py,sha256=l4ka-RcLl8uhiXMC9pfER4_jVWj1v3NJPPQZT4f7uZs,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4J
16
16
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
17
17
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
18
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
19
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
19
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=aXv8hvxHWOr5xiZIczAWdxPBV_3nEJVzJgMeNor55ps,7947
20
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
21
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
21
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=fe97b27EwWNnXxnphM6LY5CNNveJmlYNnAFND9X124Y,13239
22
22
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
23
23
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
24
24
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
203
203
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
204
204
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
205
205
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
206
- ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
207
- ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/METADATA,sha256=oSZ-le6_wi5EtIq7fD6JeElUhQnrN6h9ydk51FK_8I8,1966
208
- ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
209
- ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
210
- ai_edge_torch_nightly-0.3.0.dev20250104.dist-info/RECORD,,
206
+ ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
207
+ ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/METADATA,sha256=GGqVNvcJLsaksphF-aI8acthK7eYJQ0vKg2BktlXDRQ,1966
208
+ ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
209
+ ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
210
+ ai_edge_torch_nightly-0.3.0.dev20250106.dist-info/RECORD,,