ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.3.0.dev20240809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.3.0.dev20240809.dist-info/RECORD +141 -0
- ai_edge_torch/convert/conversion_utils.py +0 -439
- ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/top_level.txt +0 -0
|
@@ -594,11 +594,13 @@ class AutoEncoderModelLoader(BaseLoader):
|
|
|
594
594
|
up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
|
|
595
595
|
|
|
596
596
|
def __init__(self, file_name: str, names: TensorNames):
|
|
597
|
-
"""AutoEncoderModelLoader constructor.
|
|
597
|
+
"""AutoEncoderModelLoader constructor.
|
|
598
|
+
|
|
599
|
+
Can be used to load encoder and decoder models.
|
|
598
600
|
|
|
599
601
|
Args:
|
|
600
|
-
file_name (str): Path to the checkpoint. Can be a directory or an
|
|
601
|
-
|
|
602
|
+
file_name (str): Path to the checkpoint. Can be a directory or an exact
|
|
603
|
+
file.
|
|
602
604
|
names (TensorNames): An instance of `TensorNames` to determine mappings.
|
|
603
605
|
"""
|
|
604
606
|
self._file_name = file_name
|
|
@@ -617,7 +619,8 @@ class AutoEncoderModelLoader(BaseLoader):
|
|
|
617
619
|
|
|
618
620
|
Returns:
|
|
619
621
|
missing_keys (List[str]): a list of str containing the missing keys.
|
|
620
|
-
unexpected_keys (List[str]): a list of str containing the unexpected
|
|
622
|
+
unexpected_keys (List[str]): a list of str containing the unexpected
|
|
623
|
+
keys.
|
|
621
624
|
|
|
622
625
|
Raises:
|
|
623
626
|
ValueError: If conversion results in unmapped tensors and strict mode is
|
|
@@ -683,6 +686,31 @@ class AutoEncoderModelLoader(BaseLoader):
|
|
|
683
686
|
return model.load_state_dict(converted_state, strict=strict)
|
|
684
687
|
|
|
685
688
|
|
|
689
|
+
def build_attention_config(
|
|
690
|
+
num_heads,
|
|
691
|
+
dim,
|
|
692
|
+
num_query_groups,
|
|
693
|
+
rotary_percentage=0.0,
|
|
694
|
+
qkv_transpose_before_split=True,
|
|
695
|
+
qkv_use_bias=False,
|
|
696
|
+
output_proj_use_bias=True,
|
|
697
|
+
enable_kv_cache=False,
|
|
698
|
+
qkv_fused_interleaved=False,
|
|
699
|
+
):
|
|
700
|
+
|
|
701
|
+
return layers_config.AttentionConfig(
|
|
702
|
+
num_heads=num_heads,
|
|
703
|
+
head_dim=dim // num_heads,
|
|
704
|
+
num_query_groups=num_query_groups,
|
|
705
|
+
rotary_percentage=rotary_percentage,
|
|
706
|
+
qkv_transpose_before_split=qkv_transpose_before_split,
|
|
707
|
+
qkv_use_bias=qkv_use_bias,
|
|
708
|
+
output_proj_use_bias=output_proj_use_bias,
|
|
709
|
+
enable_kv_cache=enable_kv_cache,
|
|
710
|
+
qkv_fused_interleaved=qkv_fused_interleaved,
|
|
711
|
+
)
|
|
712
|
+
|
|
713
|
+
|
|
686
714
|
class DiffusionModelLoader(BaseLoader):
|
|
687
715
|
|
|
688
716
|
@dataclass
|
|
@@ -696,11 +724,13 @@ class DiffusionModelLoader(BaseLoader):
|
|
|
696
724
|
up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
|
|
697
725
|
|
|
698
726
|
def __init__(self, file_name: str, names: TensorNames):
|
|
699
|
-
"""DiffusionModelLoader constructor.
|
|
727
|
+
"""DiffusionModelLoader constructor.
|
|
728
|
+
|
|
729
|
+
Can be used to load diffusion models of Stable Diffusion.
|
|
700
730
|
|
|
701
731
|
Args:
|
|
702
|
-
file_name (str): Path to the checkpoint. Can be a directory or an
|
|
703
|
-
|
|
732
|
+
file_name (str): Path to the checkpoint. Can be a directory or an exact
|
|
733
|
+
file.
|
|
704
734
|
names (TensorNames): An instance of `TensorNames` to determine mappings.
|
|
705
735
|
"""
|
|
706
736
|
self._file_name = file_name
|
|
@@ -719,7 +749,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
|
719
749
|
|
|
720
750
|
Returns:
|
|
721
751
|
missing_keys (List[str]): a list of str containing the missing keys.
|
|
722
|
-
unexpected_keys (List[str]): a list of str containing the unexpected
|
|
752
|
+
unexpected_keys (List[str]): a list of str containing the unexpected
|
|
753
|
+
keys.
|
|
723
754
|
|
|
724
755
|
Raises:
|
|
725
756
|
ValueError: If conversion results in unmapped tensors and strict mode is
|
|
@@ -741,16 +772,6 @@ class DiffusionModelLoader(BaseLoader):
|
|
|
741
772
|
state, self._names.final_norm, converted_state, "final_norm"
|
|
742
773
|
)
|
|
743
774
|
|
|
744
|
-
attention_config = layers_config.AttentionConfig(
|
|
745
|
-
num_heads=config.transformer_num_attention_heads,
|
|
746
|
-
num_query_groups=config.transformer_num_attention_heads,
|
|
747
|
-
rotary_percentage=0.0,
|
|
748
|
-
qkv_transpose_before_split=True,
|
|
749
|
-
qkv_use_bias=False,
|
|
750
|
-
output_proj_use_bias=True,
|
|
751
|
-
enable_kv_cache=False,
|
|
752
|
-
)
|
|
753
|
-
|
|
754
775
|
# Map down_encoders.
|
|
755
776
|
output_channel = config.block_out_channels[0]
|
|
756
777
|
for i, block_out_channel in enumerate(config.block_out_channels):
|
|
@@ -781,13 +802,21 @@ class DiffusionModelLoader(BaseLoader):
|
|
|
781
802
|
attention_block_config=unet_config.AttentionBlock2DConfig(
|
|
782
803
|
dim=output_channel,
|
|
783
804
|
normalization_config=config.transformer_norm_config,
|
|
784
|
-
attention_config=attention_config,
|
|
805
|
+
attention_config=build_attention_config(
|
|
806
|
+
num_heads=config.transformer_num_attention_heads,
|
|
807
|
+
dim=output_channel,
|
|
808
|
+
num_query_groups=config.transformer_num_attention_heads,
|
|
809
|
+
),
|
|
785
810
|
),
|
|
786
811
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
|
787
812
|
query_dim=output_channel,
|
|
788
813
|
cross_dim=config.transformer_cross_attention_dim,
|
|
789
814
|
normalization_config=config.transformer_norm_config,
|
|
790
|
-
attention_config=attention_config,
|
|
815
|
+
attention_config=build_attention_config(
|
|
816
|
+
num_heads=config.transformer_num_attention_heads,
|
|
817
|
+
dim=output_channel,
|
|
818
|
+
num_query_groups=config.transformer_num_attention_heads,
|
|
819
|
+
),
|
|
791
820
|
),
|
|
792
821
|
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
|
|
793
822
|
feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
|
|
@@ -839,13 +868,21 @@ class DiffusionModelLoader(BaseLoader):
|
|
|
839
868
|
attention_block_config=unet_config.AttentionBlock2DConfig(
|
|
840
869
|
dim=mid_block_channels,
|
|
841
870
|
normalization_config=config.transformer_norm_config,
|
|
842
|
-
attention_config=attention_config,
|
|
871
|
+
attention_config=build_attention_config(
|
|
872
|
+
num_heads=config.transformer_num_attention_heads,
|
|
873
|
+
dim=mid_block_channels,
|
|
874
|
+
num_query_groups=config.transformer_num_attention_heads,
|
|
875
|
+
),
|
|
843
876
|
),
|
|
844
877
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
|
845
878
|
query_dim=mid_block_channels,
|
|
846
879
|
cross_dim=config.transformer_cross_attention_dim,
|
|
847
880
|
normalization_config=config.transformer_norm_config,
|
|
848
|
-
attention_config=attention_config,
|
|
881
|
+
attention_config=build_attention_config(
|
|
882
|
+
num_heads=config.transformer_num_attention_heads,
|
|
883
|
+
dim=mid_block_channels,
|
|
884
|
+
num_query_groups=config.transformer_num_attention_heads,
|
|
885
|
+
),
|
|
849
886
|
),
|
|
850
887
|
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
|
|
851
888
|
feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
|
|
@@ -904,13 +941,21 @@ class DiffusionModelLoader(BaseLoader):
|
|
|
904
941
|
attention_block_config=unet_config.AttentionBlock2DConfig(
|
|
905
942
|
dim=output_channel,
|
|
906
943
|
normalization_config=config.transformer_norm_config,
|
|
907
|
-
attention_config=attention_config,
|
|
944
|
+
attention_config=build_attention_config(
|
|
945
|
+
num_heads=config.transformer_num_attention_heads,
|
|
946
|
+
dim=output_channel,
|
|
947
|
+
num_query_groups=config.transformer_num_attention_heads,
|
|
948
|
+
),
|
|
908
949
|
),
|
|
909
950
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
|
910
951
|
query_dim=output_channel,
|
|
911
952
|
cross_dim=config.transformer_cross_attention_dim,
|
|
912
953
|
normalization_config=config.transformer_norm_config,
|
|
913
|
-
attention_config=attention_config,
|
|
954
|
+
attention_config=build_attention_config(
|
|
955
|
+
num_heads=config.transformer_num_attention_heads,
|
|
956
|
+
dim=output_channel,
|
|
957
|
+
num_query_groups=config.transformer_num_attention_heads,
|
|
958
|
+
),
|
|
914
959
|
),
|
|
915
960
|
pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
|
|
916
961
|
feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
|
|
@@ -92,9 +92,7 @@ def load_pytorch_statedict(full_path: str):
|
|
|
92
92
|
|
|
93
93
|
|
|
94
94
|
class ModelLoader:
|
|
95
|
-
"""
|
|
96
|
-
transformer layer format.
|
|
97
|
-
"""
|
|
95
|
+
"""Utility class for loading and converting checkpoints to ODML transformer layer format."""
|
|
98
96
|
|
|
99
97
|
@dataclass
|
|
100
98
|
class TensorNames:
|
|
@@ -121,12 +119,13 @@ class ModelLoader:
|
|
|
121
119
|
lm_head: str = None
|
|
122
120
|
|
|
123
121
|
def __init__(self, file_name: str, names: TensorNames) -> None:
|
|
124
|
-
"""ModelLoader constructor.
|
|
125
|
-
|
|
122
|
+
"""ModelLoader constructor.
|
|
123
|
+
|
|
124
|
+
Can be used to load multiple models of the same type.
|
|
126
125
|
|
|
127
126
|
Args:
|
|
128
|
-
file_name (str): Path to the checkpoint. Can be a directory or an
|
|
129
|
-
|
|
127
|
+
file_name (str): Path to the checkpoint. Can be a directory or an exact
|
|
128
|
+
file.
|
|
130
129
|
names (TensorNames): An instance of `TensorNames` to determine mappings.
|
|
131
130
|
"""
|
|
132
131
|
self._file_name = file_name
|
|
@@ -158,7 +157,7 @@ class ModelLoader:
|
|
|
158
157
|
)
|
|
159
158
|
elif isinstance(self._names, dict):
|
|
160
159
|
converted_state = {}
|
|
161
|
-
for additional_prefix,
|
|
160
|
+
for additional_prefix, _ in self._names.items():
|
|
162
161
|
local_converted_state = self._do_load(
|
|
163
162
|
model,
|
|
164
163
|
state,
|
|
@@ -212,7 +211,7 @@ class ModelLoader:
|
|
|
212
211
|
|
|
213
212
|
if names.relative_attn_bias:
|
|
214
213
|
rel_attn_name = names.relative_attn_bias
|
|
215
|
-
prefix = additional_prefix +
|
|
214
|
+
prefix = additional_prefix + "transformer_blocks.0"
|
|
216
215
|
converted_state[f"{prefix}.atten_func.relative_attention_bias.weight"] = (
|
|
217
216
|
state.pop(f"{rel_attn_name}.weight")
|
|
218
217
|
)
|
|
@@ -266,7 +265,7 @@ class ModelLoader:
|
|
|
266
265
|
if self._file_name.endswith(".bin"):
|
|
267
266
|
return load_pytorch_statedict
|
|
268
267
|
|
|
269
|
-
raise ValueError(
|
|
268
|
+
raise ValueError("File format not supported.")
|
|
270
269
|
|
|
271
270
|
def _map_feedforward(
|
|
272
271
|
self,
|
|
@@ -505,8 +504,8 @@ class ModelLoader:
|
|
|
505
504
|
q_per_kv = (
|
|
506
505
|
config.attn_config.num_heads // config.attn_config.num_query_groups
|
|
507
506
|
)
|
|
508
|
-
qs = torch.split(q, config.head_dim * q_per_kv)
|
|
509
|
-
ks = torch.split(k, config.head_dim)
|
|
510
|
-
vs = torch.split(v, config.head_dim)
|
|
507
|
+
qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
|
|
508
|
+
ks = torch.split(k, config.attn_config.head_dim)
|
|
509
|
+
vs = torch.split(v, config.attn_config.head_dim)
|
|
511
510
|
cycled = [t for group in zip(qs, ks, vs) for t in group]
|
|
512
511
|
return torch.cat(cycled)
|
ai_edge_torch/hlfb/__init__.py
CHANGED
|
@@ -13,4 +13,4 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from ai_edge_torch.lowertools import StableHLOCompositeBuilder
|
|
@@ -16,10 +16,10 @@ import copy
|
|
|
16
16
|
from typing import Any
|
|
17
17
|
import uuid
|
|
18
18
|
|
|
19
|
-
from ai_edge_torch
|
|
20
|
-
from ai_edge_torch.hlfb.mark_pattern
|
|
19
|
+
from ai_edge_torch import lowertools
|
|
20
|
+
from ai_edge_torch.hlfb.mark_pattern import passes
|
|
21
|
+
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
|
21
22
|
import torch
|
|
22
|
-
from torch_xla.experimental import xla_marker
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
@torch._dynamo.assume_constant_result
|
|
@@ -48,10 +48,10 @@ def _insert_marker(
|
|
|
48
48
|
is_input: bool,
|
|
49
49
|
attr: dict[str, Any] = None,
|
|
50
50
|
):
|
|
51
|
-
attr =
|
|
51
|
+
attr = lowertools.serialize_composite_attr(attr) if attr else None
|
|
52
52
|
with graph_module.graph.inserting_after(node):
|
|
53
53
|
new_node = graph_module.graph.call_function(
|
|
54
|
-
|
|
54
|
+
lowertools.mark_tensor_op,
|
|
55
55
|
args=(node,),
|
|
56
56
|
kwargs={
|
|
57
57
|
"name": name,
|
|
@@ -68,13 +68,16 @@ def _insert_marker(
|
|
|
68
68
|
|
|
69
69
|
def mark_pattern(
|
|
70
70
|
graph_module: torch.fx.GraphModule,
|
|
71
|
-
pattern: Pattern,
|
|
71
|
+
pattern: pattern_module.Pattern,
|
|
72
72
|
) -> torch.fx.GraphModule:
|
|
73
73
|
"""Mark all existences of pattern graph in the GraphModule with fx pattern matching.
|
|
74
|
+
|
|
74
75
|
The marked subgraphs will be lowered in StableHLO composite ops.
|
|
76
|
+
|
|
75
77
|
Args:
|
|
76
78
|
graph_module (torch.fx.GraphModule): GraphModule to be matched and marked.
|
|
77
79
|
pattern (ai_edge_torch.hlfb.mark_pattern.Pattern): Pattern to match.
|
|
80
|
+
|
|
78
81
|
Returns:
|
|
79
82
|
The modified graph_module with additional marker ops in graph.
|
|
80
83
|
"""
|
|
@@ -12,13 +12,25 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Passes to clean up the model graph for pattern matching."""
|
|
16
|
+
|
|
15
17
|
import torch
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
def remove_clone_ops(gm: torch.fx.GraphModule):
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
21
|
+
"""Removes clone ops from the graph.
|
|
22
|
+
|
|
23
|
+
torch export adds additional aten.clone nodes to produce contiguous in memory
|
|
24
|
+
tensors depending on tensor sizes for runtime efficiency. However, these
|
|
25
|
+
unpredictable clone nodes can break the pattern matching. Thus remove all
|
|
26
|
+
clones in model and pattern graphs.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
gm: The graph module to remove clone ops from.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
The graph module with clone ops removed.
|
|
33
|
+
"""
|
|
22
34
|
for node in gm.graph.nodes:
|
|
23
35
|
if node.op == "call_function" and node.name.startswith("clone"):
|
|
24
36
|
node.replace_all_uses_with(node.args[0])
|
|
@@ -30,6 +42,14 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
|
|
|
30
42
|
|
|
31
43
|
|
|
32
44
|
def remove_dangling_args(gm: torch.fx.GraphModule):
|
|
45
|
+
"""Removes dangling args from the graph.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
gm: The graph module to remove dangling args from.
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
The graph module with dangling args removed.
|
|
52
|
+
"""
|
|
33
53
|
nodes_to_erase = []
|
|
34
54
|
for node in gm.graph.nodes:
|
|
35
55
|
if node.op == "placeholder" and len(node.users) == 0:
|
|
@@ -12,7 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
|
|
15
|
+
"""Mark pattern."""
|
|
16
|
+
|
|
16
17
|
import dataclasses
|
|
17
18
|
from typing import Any, Callable, Optional, Union
|
|
18
19
|
|
|
@@ -45,6 +46,7 @@ def _are_equal(x: Any, y: Any) -> bool:
|
|
|
45
46
|
@dataclasses.dataclass
|
|
46
47
|
class ScalarAttrTracker:
|
|
47
48
|
"""ScalarAttrTracker is used to track the occurrence of a pattern's
|
|
49
|
+
|
|
48
50
|
scalar arg/attr in the pattern decomposed graph. Since a scalar attr
|
|
49
51
|
to the pattern can be transformed and turned into a/some ops' scalar
|
|
50
52
|
arg in the decomposed graph, it would be hard to programmatically get
|
|
@@ -57,11 +59,10 @@ class ScalarAttrTracker:
|
|
|
57
59
|
pattern_arg_pos (int): the index of the attr to track in the pattern's
|
|
58
60
|
export_args.
|
|
59
61
|
transform (Callable): the transform function used when targeting the
|
|
60
|
-
occurrence of the attr value in the decomposed graph. An attr value
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
the transformed value back to the original attr value.
|
|
62
|
+
occurrence of the attr value in the decomposed graph. An attr value may be
|
|
63
|
+
transformed during the decomposition and appear as a derived value.
|
|
64
|
+
inverse_transform (Callable): the inverse transform function that maps the
|
|
65
|
+
transformed value back to the original attr value.
|
|
65
66
|
"""
|
|
66
67
|
|
|
67
68
|
attr_name: str
|
|
@@ -74,6 +75,7 @@ class ScalarAttrTracker:
|
|
|
74
75
|
|
|
75
76
|
def track(self, *sources):
|
|
76
77
|
"""Register magic values to track the (transformed) attr values in
|
|
78
|
+
|
|
77
79
|
the pattern decomposed graph.
|
|
78
80
|
"""
|
|
79
81
|
for source in sources:
|
|
@@ -158,24 +160,22 @@ class Pattern:
|
|
|
158
160
|
"""The PyTorch computation pattern to match against a model.
|
|
159
161
|
|
|
160
162
|
Args:
|
|
161
|
-
name (str): the name of the pattern. It would be propagated to
|
|
162
|
-
|
|
163
|
-
|
|
163
|
+
name (str): the name of the pattern. It would be propagated to the `name`
|
|
164
|
+
attr in StableHLO composite ops for the matched model subgraphs in the
|
|
165
|
+
lowering.
|
|
164
166
|
module (torch.nn.Module or Callable): the PyTorch computation.
|
|
165
|
-
export_args (tuple[Any]): the args used to export the pattern module
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
the
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
|
|
178
|
-
The decomposition table to be run on the pattern's exported program.
|
|
167
|
+
export_args (tuple[Any]): the args used to export the pattern module with
|
|
168
|
+
torch.export.export. If export_args contains non-tensor Python scalars,
|
|
169
|
+
there must be a corresponding attr tracker in `scalar_attr_trackers` for
|
|
170
|
+
each scalar arg. attr_builder (Callable[[Pattern, GraphModule,
|
|
171
|
+
InternalMatch], Optional[dict[str, Any]]]): the callable that produces
|
|
172
|
+
a scalar attrs dict, which would be propagated to `attr` in
|
|
173
|
+
StableHLO composite ops for the matched model subgraphs in the lowering.
|
|
174
|
+
scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
|
|
175
|
+
args in `export_args`, which are used to track the attr occurrence(s)
|
|
176
|
+
and retrieve their values from the matched subgraph.
|
|
177
|
+
decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]): The
|
|
178
|
+
decomposition table to be run on the pattern's exported program.
|
|
179
179
|
"""
|
|
180
180
|
if not isinstance(module, torch.nn.Module):
|
|
181
181
|
|
|
@@ -12,12 +12,14 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Tests for mark_pattern."""
|
|
15
16
|
|
|
16
|
-
import unittest
|
|
17
|
-
|
|
17
|
+
from ai_edge_torch import lowertools
|
|
18
18
|
from ai_edge_torch.hlfb import mark_pattern
|
|
19
|
+
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
|
19
20
|
import torch
|
|
20
|
-
|
|
21
|
+
|
|
22
|
+
from tensorflow.python.platform import googletest
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
def _export_stablehlo_mlir(model, args=None):
|
|
@@ -25,11 +27,10 @@ def _export_stablehlo_mlir(model, args=None):
|
|
|
25
27
|
ep = torch.export.export(model, args)
|
|
26
28
|
else:
|
|
27
29
|
ep = model
|
|
28
|
-
|
|
29
|
-
return stablehlo_gm.get_stablehlo_text()
|
|
30
|
+
return lowertools.exported_program_to_mlir_text(ep)
|
|
30
31
|
|
|
31
32
|
|
|
32
|
-
class TestMarkPattern(unittest.TestCase):
|
|
33
|
+
class TestMarkPattern(googletest.TestCase):
|
|
33
34
|
|
|
34
35
|
def test_mark_pattern(self):
|
|
35
36
|
|
|
@@ -38,7 +39,7 @@ class TestMarkPattern(unittest.TestCase):
|
|
|
38
39
|
def forward(self, x):
|
|
39
40
|
return x * x + x + x
|
|
40
41
|
|
|
41
|
-
pattern =
|
|
42
|
+
pattern = pattern_module.Pattern(
|
|
42
43
|
"test.add",
|
|
43
44
|
lambda a, b: a + b,
|
|
44
45
|
export_args=(torch.rand(2, 2), torch.rand(2, 2)),
|
|
@@ -58,7 +59,7 @@ class TestMarkPattern(unittest.TestCase):
|
|
|
58
59
|
def forward(self, x):
|
|
59
60
|
return x * x * x + x - x * x + x
|
|
60
61
|
|
|
61
|
-
pattern =
|
|
62
|
+
pattern = pattern_module.Pattern(
|
|
62
63
|
"test.add",
|
|
63
64
|
lambda a, b: a + b,
|
|
64
65
|
export_args=(torch.rand(2, 2), torch.rand(2, 2)),
|
|
@@ -85,12 +86,12 @@ class TestMarkPattern(unittest.TestCase):
|
|
|
85
86
|
r = torch.nn.LogSoftmax(dim=idx % 2)(r) * x
|
|
86
87
|
return r
|
|
87
88
|
|
|
88
|
-
pattern =
|
|
89
|
+
pattern = pattern_module.Pattern(
|
|
89
90
|
"test.log_softmax",
|
|
90
91
|
lambda x, dim: torch.nn.functional.log_softmax(x, dim=dim),
|
|
91
92
|
export_args=(torch.rand(10, 10, 10), 1),
|
|
92
93
|
scalar_attr_trackers=[
|
|
93
|
-
|
|
94
|
+
pattern_module.ScalarAttrTracker("dim", pattern_arg_pos=1)
|
|
94
95
|
.track(0)
|
|
95
96
|
.track(1)
|
|
96
97
|
.track(2),
|
|
@@ -115,7 +116,7 @@ class TestMarkPattern(unittest.TestCase):
|
|
|
115
116
|
z = z + y
|
|
116
117
|
return z
|
|
117
118
|
|
|
118
|
-
pattern =
|
|
119
|
+
pattern = pattern_module.Pattern(
|
|
119
120
|
"test.relu",
|
|
120
121
|
lambda x: torch.ops.aten.relu(x),
|
|
121
122
|
export_args=(torch.rand(2, 2),),
|
|
@@ -131,4 +132,4 @@ class TestMarkPattern(unittest.TestCase):
|
|
|
131
132
|
|
|
132
133
|
|
|
133
134
|
if __name__ == "__main__":
|
|
134
|
-
|
|
135
|
+
googletest.main()
|
|
@@ -12,22 +12,24 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
"""Tests for StableHLOCompositeBuilder."""
|
|
16
|
+
|
|
15
17
|
import math
|
|
16
|
-
import unittest
|
|
17
18
|
|
|
19
|
+
from ai_edge_torch import lowertools
|
|
18
20
|
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
19
21
|
import torch
|
|
20
22
|
import torch.nn.functional as F
|
|
21
|
-
|
|
23
|
+
|
|
24
|
+
from tensorflow.python.platform import googletest
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
def _export_stablehlo_mlir(model, args):
|
|
25
28
|
ep = torch.export.export(model, args)
|
|
26
|
-
|
|
27
|
-
return stablehlo_gm.get_stablehlo_text()
|
|
29
|
+
return lowertools.exported_program_to_mlir_text(ep)
|
|
28
30
|
|
|
29
31
|
|
|
30
|
-
class TestStableHLOCompositeBuilder(unittest.TestCase):
|
|
32
|
+
class TestStableHLOCompositeBuilder(googletest.TestCase):
|
|
31
33
|
|
|
32
34
|
def test_build_composite(self):
|
|
33
35
|
class SampleModel(torch.nn.Module):
|
|
@@ -273,4 +275,4 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
|
|
|
273
275
|
|
|
274
276
|
|
|
275
277
|
if __name__ == "__main__":
|
|
276
|
-
|
|
278
|
+
googletest.main()
|
|
@@ -13,4 +13,4 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from ._shim import *
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Any, Optional
|
|
17
|
+
|
|
18
|
+
from ai_edge_torch import config
|
|
19
|
+
from ai_edge_torch._convert import signature
|
|
20
|
+
from ai_edge_torch.quantize import quant_config as qcfg
|
|
21
|
+
import torch
|
|
22
|
+
|
|
23
|
+
# isort: off
|
|
24
|
+
if config.Config.use_torch_xla:
|
|
25
|
+
from ai_edge_torch.lowertools import torch_xla_utils as utils
|
|
26
|
+
from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
|
|
27
|
+
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
|
|
28
|
+
from torch_xla.experimental.xla_marker import serialize_composite_attr
|
|
29
|
+
# The following imports are needed to register the needed torch_xla ops.
|
|
30
|
+
import torch_xla.experimental.xla_marker
|
|
31
|
+
import torch_xla.experimental.xla_mlir_debuginfo
|
|
32
|
+
|
|
33
|
+
mark_tensor_op = torch.ops.xla.mark_tensor.default
|
|
34
|
+
write_mlir_debuginfo_op = torch.ops.xla.write_mlir_debuginfo.default
|
|
35
|
+
else:
|
|
36
|
+
from ai_edge_torch.lowertools import odml_torch_utils as utils
|
|
37
|
+
from ai_edge_torch.lowertools.odml_torch_utils import exported_program_to_mlir_text
|
|
38
|
+
from ai_edge_torch.odml_torch.composite import StableHLOCompositeBuilder
|
|
39
|
+
from ai_edge_torch.odml_torch.composite.mark_tensor import serialize_composite_attr
|
|
40
|
+
from ai_edge_torch.odml_torch.composite.mark_tensor import mark_tensor_op
|
|
41
|
+
from ai_edge_torch.odml_torch.debuginfo import write_mlir_debuginfo_op
|
|
42
|
+
# isort: on
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def exported_programs_to_tflite(
|
|
46
|
+
exported_programs: list[torch.export.ExportedProgram],
|
|
47
|
+
signatures: list[signature.Signature],
|
|
48
|
+
*,
|
|
49
|
+
quant_config: Optional[qcfg.QuantConfig] = None,
|
|
50
|
+
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
|
51
|
+
):
|
|
52
|
+
"""Converts a list of ExportedProgram to a TFLite model.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
exported_programs: A list of ExportedProgram.
|
|
56
|
+
signatures: A list of Signature.
|
|
57
|
+
quant_config: A QuantConfig.
|
|
58
|
+
_tfl_converter_flags: A dict of flags for TFLiteConverter.
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
A TFLite model.
|
|
62
|
+
"""
|
|
63
|
+
if _tfl_converter_flags is None:
|
|
64
|
+
_tfl_converter_flags = {}
|
|
65
|
+
|
|
66
|
+
bundles: list[utils.MlirBundle] = [
|
|
67
|
+
utils.exported_program_to_mlir(exported, sig.flat_args)
|
|
68
|
+
for exported, sig in zip(exported_programs, signatures)
|
|
69
|
+
]
|
|
70
|
+
|
|
71
|
+
merged_bundle: utils.MergedBundle = utils.merge_mlir_bundles(
|
|
72
|
+
bundles, signatures, exported_programs
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
return utils.merged_bundle_to_tfl_model(
|
|
76
|
+
merged_bundle,
|
|
77
|
+
signatures,
|
|
78
|
+
quant_config=quant_config,
|
|
79
|
+
_tfl_converter_flags=_tfl_converter_flags,
|
|
80
|
+
)
|