ai-edge-torch-nightly 0.2.0.dev20240805__py3-none-any.whl → 0.2.0.dev20240808__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (104)
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
  88. ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.2.0.dev20240808.dist-info/RECORD +141 -0
  97. ai_edge_torch/convert/conversion_utils.py +0 -439
  98. ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +0 -133
  99. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  101. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/LICENSE +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/WHEEL +0 -0
  104. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/top_level.txt +0 -0
@@ -594,11 +594,13 @@ class AutoEncoderModelLoader(BaseLoader):
594
594
  up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
595
595
 
596
596
  def __init__(self, file_name: str, names: TensorNames):
597
- """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
597
+ """AutoEncoderModelLoader constructor.
598
+
599
+ Can be used to load encoder and decoder models.
598
600
 
599
601
  Args:
600
- file_name (str): Path to the checkpoint. Can be a directory or an
601
- exact file.
602
+ file_name (str): Path to the checkpoint. Can be a directory or an exact
603
+ file.
602
604
  names (TensorNames): An instance of `TensorNames` to determine mappings.
603
605
  """
604
606
  self._file_name = file_name
@@ -617,7 +619,8 @@ class AutoEncoderModelLoader(BaseLoader):
617
619
 
618
620
  Returns:
619
621
  missing_keys (List[str]): a list of str containing the missing keys.
620
- unexpected_keys (List[str]): a list of str containing the unexpected keys.
622
+ unexpected_keys (List[str]): a list of str containing the unexpected
623
+ keys.
621
624
 
622
625
  Raises:
623
626
  ValueError: If conversion results in unmapped tensors and strict mode is
@@ -683,6 +686,31 @@ class AutoEncoderModelLoader(BaseLoader):
683
686
  return model.load_state_dict(converted_state, strict=strict)
684
687
 
685
688
 
689
+ def build_attention_config(
690
+ num_heads,
691
+ dim,
692
+ num_query_groups,
693
+ rotary_percentage=0.0,
694
+ qkv_transpose_before_split=True,
695
+ qkv_use_bias=False,
696
+ output_proj_use_bias=True,
697
+ enable_kv_cache=False,
698
+ qkv_fused_interleaved=False,
699
+ ):
700
+
701
+ return layers_config.AttentionConfig(
702
+ num_heads=num_heads,
703
+ head_dim=dim // num_heads,
704
+ num_query_groups=num_query_groups,
705
+ rotary_percentage=rotary_percentage,
706
+ qkv_transpose_before_split=qkv_transpose_before_split,
707
+ qkv_use_bias=qkv_use_bias,
708
+ output_proj_use_bias=output_proj_use_bias,
709
+ enable_kv_cache=enable_kv_cache,
710
+ qkv_fused_interleaved=qkv_fused_interleaved,
711
+ )
712
+
713
+
686
714
  class DiffusionModelLoader(BaseLoader):
687
715
 
688
716
  @dataclass
@@ -696,11 +724,13 @@ class DiffusionModelLoader(BaseLoader):
696
724
  up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
697
725
 
698
726
  def __init__(self, file_name: str, names: TensorNames):
699
- """DiffusionModelLoader constructor. Can be used to load diffusion models of Stable Diffusion.
727
+ """DiffusionModelLoader constructor.
728
+
729
+ Can be used to load diffusion models of Stable Diffusion.
700
730
 
701
731
  Args:
702
- file_name (str): Path to the checkpoint. Can be a directory or an
703
- exact file.
732
+ file_name (str): Path to the checkpoint. Can be a directory or an exact
733
+ file.
704
734
  names (TensorNames): An instance of `TensorNames` to determine mappings.
705
735
  """
706
736
  self._file_name = file_name
@@ -719,7 +749,8 @@ class DiffusionModelLoader(BaseLoader):
719
749
 
720
750
  Returns:
721
751
  missing_keys (List[str]): a list of str containing the missing keys.
722
- unexpected_keys (List[str]): a list of str containing the unexpected keys.
752
+ unexpected_keys (List[str]): a list of str containing the unexpected
753
+ keys.
723
754
 
724
755
  Raises:
725
756
  ValueError: If conversion results in unmapped tensors and strict mode is
@@ -741,16 +772,6 @@ class DiffusionModelLoader(BaseLoader):
741
772
  state, self._names.final_norm, converted_state, "final_norm"
742
773
  )
743
774
 
744
- attention_config = layers_config.AttentionConfig(
745
- num_heads=config.transformer_num_attention_heads,
746
- num_query_groups=config.transformer_num_attention_heads,
747
- rotary_percentage=0.0,
748
- qkv_transpose_before_split=True,
749
- qkv_use_bias=False,
750
- output_proj_use_bias=True,
751
- enable_kv_cache=False,
752
- )
753
-
754
775
  # Map down_encoders.
755
776
  output_channel = config.block_out_channels[0]
756
777
  for i, block_out_channel in enumerate(config.block_out_channels):
@@ -781,13 +802,21 @@ class DiffusionModelLoader(BaseLoader):
781
802
  attention_block_config=unet_config.AttentionBlock2DConfig(
782
803
  dim=output_channel,
783
804
  normalization_config=config.transformer_norm_config,
784
- attention_config=attention_config,
805
+ attention_config=build_attention_config(
806
+ num_heads=config.transformer_num_attention_heads,
807
+ dim=output_channel,
808
+ num_query_groups=config.transformer_num_attention_heads,
809
+ ),
785
810
  ),
786
811
  cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
787
812
  query_dim=output_channel,
788
813
  cross_dim=config.transformer_cross_attention_dim,
789
814
  normalization_config=config.transformer_norm_config,
790
- attention_config=attention_config,
815
+ attention_config=build_attention_config(
816
+ num_heads=config.transformer_num_attention_heads,
817
+ dim=output_channel,
818
+ num_query_groups=config.transformer_num_attention_heads,
819
+ ),
791
820
  ),
792
821
  pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
793
822
  feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
@@ -839,13 +868,21 @@ class DiffusionModelLoader(BaseLoader):
839
868
  attention_block_config=unet_config.AttentionBlock2DConfig(
840
869
  dim=mid_block_channels,
841
870
  normalization_config=config.transformer_norm_config,
842
- attention_config=attention_config,
871
+ attention_config=build_attention_config(
872
+ num_heads=config.transformer_num_attention_heads,
873
+ dim=mid_block_channels,
874
+ num_query_groups=config.transformer_num_attention_heads,
875
+ ),
843
876
  ),
844
877
  cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
845
878
  query_dim=mid_block_channels,
846
879
  cross_dim=config.transformer_cross_attention_dim,
847
880
  normalization_config=config.transformer_norm_config,
848
- attention_config=attention_config,
881
+ attention_config=build_attention_config(
882
+ num_heads=config.transformer_num_attention_heads,
883
+ dim=mid_block_channels,
884
+ num_query_groups=config.transformer_num_attention_heads,
885
+ ),
849
886
  ),
850
887
  pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
851
888
  feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
@@ -904,13 +941,21 @@ class DiffusionModelLoader(BaseLoader):
904
941
  attention_block_config=unet_config.AttentionBlock2DConfig(
905
942
  dim=output_channel,
906
943
  normalization_config=config.transformer_norm_config,
907
- attention_config=attention_config,
944
+ attention_config=build_attention_config(
945
+ num_heads=config.transformer_num_attention_heads,
946
+ dim=output_channel,
947
+ num_query_groups=config.transformer_num_attention_heads,
948
+ ),
908
949
  ),
909
950
  cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
910
951
  query_dim=output_channel,
911
952
  cross_dim=config.transformer_cross_attention_dim,
912
953
  normalization_config=config.transformer_norm_config,
913
- attention_config=attention_config,
954
+ attention_config=build_attention_config(
955
+ num_heads=config.transformer_num_attention_heads,
956
+ dim=output_channel,
957
+ num_query_groups=config.transformer_num_attention_heads,
958
+ ),
914
959
  ),
915
960
  pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
916
961
  feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
@@ -92,9 +92,7 @@ def load_pytorch_statedict(full_path: str):
92
92
 
93
93
 
94
94
  class ModelLoader:
95
- """A utility class for loading and converting model checkpoints to ODML
96
- transformer layer format.
97
- """
95
+ """Utility class for loading and converting checkpoints to ODML transformer layer format."""
98
96
 
99
97
  @dataclass
100
98
  class TensorNames:
@@ -121,12 +119,13 @@ class ModelLoader:
121
119
  lm_head: str = None
122
120
 
123
121
  def __init__(self, file_name: str, names: TensorNames) -> None:
124
- """ModelLoader constructor. Can be used to load multiple models of the same
125
- type.
122
+ """ModelLoader constructor.
123
+
124
+ Can be used to load multiple models of the same type.
126
125
 
127
126
  Args:
128
- file_name (str): Path to the checkpoint. Can be a directory or an
129
- exact file.
127
+ file_name (str): Path to the checkpoint. Can be a directory or an exact
128
+ file.
130
129
  names (TensorNames): An instance of `TensorNames` to determine mappings.
131
130
  """
132
131
  self._file_name = file_name
@@ -158,7 +157,7 @@ class ModelLoader:
158
157
  )
159
158
  elif isinstance(self._names, dict):
160
159
  converted_state = {}
161
- for additional_prefix, names in self._names.items():
160
+ for additional_prefix, _ in self._names.items():
162
161
  local_converted_state = self._do_load(
163
162
  model,
164
163
  state,
@@ -212,7 +211,7 @@ class ModelLoader:
212
211
 
213
212
  if names.relative_attn_bias:
214
213
  rel_attn_name = names.relative_attn_bias
215
- prefix = additional_prefix + f"transformer_blocks.0"
214
+ prefix = additional_prefix + "transformer_blocks.0"
216
215
  converted_state[f"{prefix}.atten_func.relative_attention_bias.weight"] = (
217
216
  state.pop(f"{rel_attn_name}.weight")
218
217
  )
@@ -266,7 +265,7 @@ class ModelLoader:
266
265
  if self._file_name.endswith(".bin"):
267
266
  return load_pytorch_statedict
268
267
 
269
- raise ValueError(f"File format not supported.")
268
+ raise ValueError("File format not supported.")
270
269
 
271
270
  def _map_feedforward(
272
271
  self,
@@ -505,8 +504,8 @@ class ModelLoader:
505
504
  q_per_kv = (
506
505
  config.attn_config.num_heads // config.attn_config.num_query_groups
507
506
  )
508
- qs = torch.split(q, config.head_dim * q_per_kv)
509
- ks = torch.split(k, config.head_dim)
510
- vs = torch.split(v, config.head_dim)
507
+ qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
508
+ ks = torch.split(k, config.attn_config.head_dim)
509
+ vs = torch.split(v, config.attn_config.head_dim)
511
510
  cycled = [t for group in zip(qs, ks, vs) for t in group]
512
511
  return torch.cat(cycled)
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
16
+ from ai_edge_torch.lowertools import StableHLOCompositeBuilder
@@ -16,10 +16,10 @@ import copy
16
16
  from typing import Any
17
17
  import uuid
18
18
 
19
- from ai_edge_torch.hlfb.mark_pattern.pattern import Pattern
20
- from ai_edge_torch.hlfb.mark_pattern.pattern import ScalarAttrTracker # NOQA
19
+ from ai_edge_torch import lowertools
20
+ from ai_edge_torch.hlfb.mark_pattern import passes
21
+ from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
21
22
  import torch
22
- from torch_xla.experimental import xla_marker
23
23
 
24
24
 
25
25
  @torch._dynamo.assume_constant_result
@@ -48,10 +48,10 @@ def _insert_marker(
48
48
  is_input: bool,
49
49
  attr: dict[str, Any] = None,
50
50
  ):
51
- attr = xla_marker.serialize_composite_attr(attr) if attr else None
51
+ attr = lowertools.serialize_composite_attr(attr) if attr else None
52
52
  with graph_module.graph.inserting_after(node):
53
53
  new_node = graph_module.graph.call_function(
54
- torch.ops.xla.mark_tensor,
54
+ lowertools.mark_tensor_op,
55
55
  args=(node,),
56
56
  kwargs={
57
57
  "name": name,
@@ -68,13 +68,16 @@ def _insert_marker(
68
68
 
69
69
  def mark_pattern(
70
70
  graph_module: torch.fx.GraphModule,
71
- pattern: Pattern,
71
+ pattern: pattern_module.Pattern,
72
72
  ) -> torch.fx.GraphModule:
73
73
  """Mark all existences of pattern graph in the GraphModule with fx pattern matching.
74
+
74
75
  The marked subgraphs will be lowered in StableHLO composite ops.
76
+
75
77
  Args:
76
78
  graph_module (torch.fx.GraphModule): GraphModule to be matched and marked.
77
79
  pattern (ai_edge_torch.hlfb.mark_pattern.Pattern): Pattern to match.
80
+
78
81
  Returns:
79
82
  The modified graph_module with additional marker ops in graph.
80
83
  """
@@ -12,13 +12,25 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Passes to clean up the model graph for pattern matching."""
16
+
15
17
  import torch
16
18
 
17
19
 
18
20
  def remove_clone_ops(gm: torch.fx.GraphModule):
19
- # torch export adds additional aten.clone nodes to produce contiguous in memory tensors
20
- # depending on tensor sizes for runtime efficiency. However, these unpredictable clone
21
- # nodes can break the pattern matching. Thus remove all clones in model and pattern graphs.
21
+ """Removes clone ops from the graph.
22
+
23
+ torch export adds additional aten.clone nodes to produce contiguous in memory
24
+ tensors depending on tensor sizes for runtime efficiency. However, these
25
+ unpredictable clone nodes can break the pattern matching. Thus remove all
26
+ clones in model and pattern graphs.
27
+
28
+ Args:
29
+ gm: The graph module to remove clone ops from.
30
+
31
+ Returns:
32
+ The graph module with clone ops removed.
33
+ """
22
34
  for node in gm.graph.nodes:
23
35
  if node.op == "call_function" and node.name.startswith("clone"):
24
36
  node.replace_all_uses_with(node.args[0])
@@ -30,6 +42,14 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
30
42
 
31
43
 
32
44
  def remove_dangling_args(gm: torch.fx.GraphModule):
45
+ """Removes dangling args from the graph.
46
+
47
+ Args:
48
+ gm: The graph module to remove dangling args from.
49
+
50
+ Returns:
51
+ The graph module with dangling args removed.
52
+ """
33
53
  nodes_to_erase = []
34
54
  for node in gm.graph.nodes:
35
55
  if node.op == "placeholder" and len(node.users) == 0:
@@ -12,7 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import copy
15
+ """Mark pattern."""
16
+
16
17
  import dataclasses
17
18
  from typing import Any, Callable, Optional, Union
18
19
 
@@ -45,6 +46,7 @@ def _are_equal(x: Any, y: Any) -> bool:
45
46
  @dataclasses.dataclass
46
47
  class ScalarAttrTracker:
47
48
  """ScalarAttrTracker is used to track the occurrence of a pattern's
49
+
48
50
  scalar arg/attr in the pattern decomposed graph. Since a scalar attr
49
51
  to the pattern can be transformed and turned into a/some ops' scalar
50
52
  arg in the decomposed graph, it would be hard to programmatically get
@@ -57,11 +59,10 @@ class ScalarAttrTracker:
57
59
  pattern_arg_pos (int): the index of the attr to track in the pattern's
58
60
  export_args.
59
61
  transform (Callable): the transform function used when targeting the
60
- occurrence of the attr value in the decomposed graph. An attr value
61
- may be transformed during the decomposition and appear as a derived
62
- value.
63
- inverse_transform (Callable): the inverse transform function that maps
64
- the transformed value back to the original attr value.
62
+ occurrence of the attr value in the decomposed graph. An attr value may be
63
+ transformed during the decomposition and appear as a derived value.
64
+ inverse_transform (Callable): the inverse transform function that maps the
65
+ transformed value back to the original attr value.
65
66
  """
66
67
 
67
68
  attr_name: str
@@ -74,6 +75,7 @@ class ScalarAttrTracker:
74
75
 
75
76
  def track(self, *sources):
76
77
  """Register magic values to track the (transformed) attr values in
78
+
77
79
  the pattern decomposed graph.
78
80
  """
79
81
  for source in sources:
@@ -158,24 +160,22 @@ class Pattern:
158
160
  """The PyTorch computation pattern to match against a model.
159
161
 
160
162
  Args:
161
- name (str): the name of the pattern. It would be propagated to
162
- the `name` attr in StableHLO composite ops for the matched
163
- model subgraphs in the lowering.
163
+ name (str): the name of the pattern. It would be propagated to the `name`
164
+ attr in StableHLO composite ops for the matched model subgraphs in the
165
+ lowering.
164
166
  module (torch.nn.Module or Callable): the PyTorch computation.
165
- export_args (tuple[Any]): the args used to export the pattern module
166
- with torch.export.export. If export_args contains non-tensor
167
- Python scalars, there must be a corresponding attr tracker
168
- in `scalar_attr_trackers` for each scalar arg.
169
- attr_builder (Callable[[Pattern, GraphModule, InternalMatch], Optional[dict[str, Any]]]):
170
- the callable that produces the a scalar attrs dict, which would be
171
- propagated to `attr` in StableHLO composite ops for the matched
172
- model subgraphs in the lowering.
173
- scalar_attr_trackers (list[ScalarAttrTracker]): the trackers
174
- for scalar args in `export_args`, which are used to track
175
- the attr occurrence(s) and retrieve their values from the
176
- matched subgraph.
177
- decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
178
- The decomposition table to be run on the pattern's exported program.
167
+ export_args (tuple[Any]): the args used to export the pattern module with
168
+ torch.export.export. If export_args contains non-tensor Python scalars,
169
+ there must be a corresponding attr tracker in `scalar_attr_trackers` for
170
+ each scalar arg. attr_builder (Callable[[Pattern, GraphModule,
171
+ InternalMatch], Optional[dict[str, Any]]]): the callable that produces
172
+ the a scalar attrs dict, which would be propagated to `attr` in
173
+ StableHLO composite ops for the matched model subgraphs in the lowering.
174
+ scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
175
+ args in `export_args`, which are used to track the attr occurrence(s)
176
+ and retrieve their values from the matched subgraph.
177
+ decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]): The
178
+ decomposition table to be run on the pattern's exported program.
179
179
  """
180
180
  if not isinstance(module, torch.nn.Module):
181
181
 
@@ -12,12 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Tests for mark_pattern."""
15
16
 
16
- import unittest
17
-
17
+ from ai_edge_torch import lowertools
18
18
  from ai_edge_torch.hlfb import mark_pattern
19
+ from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
19
20
  import torch
20
- import torch_xla
21
+
22
+ from tensorflow.python.platform import googletest
21
23
 
22
24
 
23
25
  def _export_stablehlo_mlir(model, args=None):
@@ -25,11 +27,10 @@ def _export_stablehlo_mlir(model, args=None):
25
27
  ep = torch.export.export(model, args)
26
28
  else:
27
29
  ep = model
28
- stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
29
- return stablehlo_gm.get_stablehlo_text()
30
+ return lowertools.exported_program_to_mlir_text(ep)
30
31
 
31
32
 
32
- class TestMarkPattern(unittest.TestCase):
33
+ class TestMarkPattern(googletest.TestCase):
33
34
 
34
35
  def test_mark_pattern(self):
35
36
 
@@ -38,7 +39,7 @@ class TestMarkPattern(unittest.TestCase):
38
39
  def forward(self, x):
39
40
  return x * x + x + x
40
41
 
41
- pattern = mark_pattern.Pattern(
42
+ pattern = pattern_module.Pattern(
42
43
  "test.add",
43
44
  lambda a, b: a + b,
44
45
  export_args=(torch.rand(2, 2), torch.rand(2, 2)),
@@ -58,7 +59,7 @@ class TestMarkPattern(unittest.TestCase):
58
59
  def forward(self, x):
59
60
  return x * x * x + x - x * x + x
60
61
 
61
- pattern = mark_pattern.Pattern(
62
+ pattern = pattern_module.Pattern(
62
63
  "test.add",
63
64
  lambda a, b: a + b,
64
65
  export_args=(torch.rand(2, 2), torch.rand(2, 2)),
@@ -85,12 +86,12 @@ class TestMarkPattern(unittest.TestCase):
85
86
  r = torch.nn.LogSoftmax(dim=idx % 2)(r) * x
86
87
  return r
87
88
 
88
- pattern = mark_pattern.Pattern(
89
+ pattern = pattern_module.Pattern(
89
90
  "test.log_softmax",
90
91
  lambda x, dim: torch.nn.functional.log_softmax(x, dim=dim),
91
92
  export_args=(torch.rand(10, 10, 10), 1),
92
93
  scalar_attr_trackers=[
93
- mark_pattern.ScalarAttrTracker("dim", pattern_arg_pos=1)
94
+ pattern_module.ScalarAttrTracker("dim", pattern_arg_pos=1)
94
95
  .track(0)
95
96
  .track(1)
96
97
  .track(2),
@@ -115,7 +116,7 @@ class TestMarkPattern(unittest.TestCase):
115
116
  z = z + y
116
117
  return z
117
118
 
118
- pattern = mark_pattern.Pattern(
119
+ pattern = pattern_module.Pattern(
119
120
  "test.relu",
120
121
  lambda x: torch.ops.aten.relu(x),
121
122
  export_args=(torch.rand(2, 2),),
@@ -131,4 +132,4 @@ class TestMarkPattern(unittest.TestCase):
131
132
 
132
133
 
133
134
  if __name__ == "__main__":
134
- unittest.main()
135
+ googletest.main()
@@ -12,22 +12,24 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Tests for StableHLOCompositeBuilder."""
16
+
15
17
  import math
16
- import unittest
17
18
 
19
+ from ai_edge_torch import lowertools
18
20
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
19
21
  import torch
20
22
  import torch.nn.functional as F
21
- import torch_xla
23
+
24
+ from tensorflow.python.platform import googletest
22
25
 
23
26
 
24
27
  def _export_stablehlo_mlir(model, args):
25
28
  ep = torch.export.export(model, args)
26
- stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
27
- return stablehlo_gm.get_stablehlo_text()
29
+ return lowertools.exported_program_to_mlir_text(ep)
28
30
 
29
31
 
30
- class TestStableHLOCompositeBuilder(unittest.TestCase):
32
+ class TestStableHLOCompositeBuilder(googletest.TestCase):
31
33
 
32
34
  def test_build_composite(self):
33
35
  class SampleModel(torch.nn.Module):
@@ -273,4 +275,4 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
273
275
 
274
276
 
275
277
  if __name__ == "__main__":
276
- unittest.main()
278
+ googletest.main()
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
16
+ from ._shim import *
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Any, Optional
17
+
18
+ from ai_edge_torch import config
19
+ from ai_edge_torch._convert import signature
20
+ from ai_edge_torch.quantize import quant_config as qcfg
21
+ import torch
22
+
23
+ # isort: off
24
+ if config.Config.use_torch_xla:
25
+ from ai_edge_torch.lowertools import torch_xla_utils as utils
26
+ from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
27
+ from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
28
+ from torch_xla.experimental.xla_marker import serialize_composite_attr
29
+ # The following imports are needed to register the needed torch_xla ops.
30
+ import torch_xla.experimental.xla_marker
31
+ import torch_xla.experimental.xla_mlir_debuginfo
32
+
33
+ mark_tensor_op = torch.ops.xla.mark_tensor.default
34
+ write_mlir_debuginfo_op = torch.ops.xla.write_mlir_debuginfo.default
35
+ else:
36
+ from ai_edge_torch.lowertools import odml_torch_utils as utils
37
+ from ai_edge_torch.lowertools.odml_torch_utils import exported_program_to_mlir_text
38
+ from ai_edge_torch.odml_torch.composite import StableHLOCompositeBuilder
39
+ from ai_edge_torch.odml_torch.composite.mark_tensor import serialize_composite_attr
40
+ from ai_edge_torch.odml_torch.composite.mark_tensor import mark_tensor_op
41
+ from ai_edge_torch.odml_torch.debuginfo import write_mlir_debuginfo_op
42
+ # isort: on
43
+
44
+
45
+ def exported_programs_to_tflite(
46
+ exported_programs: list[torch.export.ExportedProgram],
47
+ signatures: list[signature.Signature],
48
+ *,
49
+ quant_config: Optional[qcfg.QuantConfig] = None,
50
+ _tfl_converter_flags: Optional[dict[str, Any]] = None,
51
+ ):
52
+ """Converts a list of ExportedProgram to a TFLite model.
53
+
54
+ Args:
55
+ exported_programs: A list of ExportedProgram.
56
+ signatures: A list of Signature.
57
+ quant_config: A QuantConfig.
58
+ _tfl_converter_flags: A dict of flags for TFLiteConverter.
59
+
60
+ Returns:
61
+ A TFLite model.
62
+ """
63
+ if _tfl_converter_flags is None:
64
+ _tfl_converter_flags = {}
65
+
66
+ bundles: list[utils.MlirBundle] = [
67
+ utils.exported_program_to_mlir(exported, sig.flat_args)
68
+ for exported, sig in zip(exported_programs, signatures)
69
+ ]
70
+
71
+ merged_bundle: utils.MergedBundle = utils.merge_mlir_bundles(
72
+ bundles, signatures, exported_programs
73
+ )
74
+
75
+ return utils.merged_bundle_to_tfl_model(
76
+ merged_bundle,
77
+ signatures,
78
+ quant_config=quant_config,
79
+ _tfl_converter_flags=_tfl_converter_flags,
80
+ )