ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (121) hide show
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
@@ -0,0 +1,151 @@
1
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Union
18
+
19
+ from ai_edge_torch.generative.quantize import quant_attrs
20
+ from ai_edge_torch.generative.quantize import supported_schemes
21
+
22
+
23
@dataclass
class LayerQuantRecipe:
    """Quantization recipe for a single Edge Generative API layer (e.g. Attention).

    Generic layer-scoped quantization recipe that specifies how this layer
    should be quantized by the Edge Generative API. This is applicable to
    layers implemented in ai_edge_torch/generative/layers/. Combinations of
    attributes that are not supported during runtime will be detected when
    .verify() is called.

    Attributes:
      activation_dtype: Desired data type of activation tensors.
      weight_dtype: Desired data type of weight tensors.
      mode: Type of quantization.
      algorithm: Algorithm for calculating quantization parameters.
      granularity: Granularity of quantization.
    """

    activation_dtype: quant_attrs.Dtype
    weight_dtype: quant_attrs.Dtype
    mode: quant_attrs.Mode
    algorithm: quant_attrs.Algorithm
    granularity: quant_attrs.Granularity

    def __str__(self):
        """Returns a one-line summary built from the `.name` of each attribute."""
        return (
            f'(a:{self.activation_dtype.name}, '
            f'w:{self.weight_dtype.name}, '
            f'{self.mode.name}, '
            f'{self.algorithm.name}, '
            f'{self.granularity.name})'
        )

    # repr() and str() intentionally produce the same compact summary.
    __repr__ = __str__

    def verify(self):
        """Checks if all attributes configured are supported in runtime.

        The five attributes are compared, as one combination, against every
        scheme returned by supported_schemes.get_supported_layer_schemes().

        Raises:
          ValueError: If the configured combination matches none of the
            supported schemes.
        """
        requested = (
            self.activation_dtype,
            self.weight_dtype,
            self.mode,
            self.algorithm,
            self.granularity,
        )
        supported = supported_schemes.get_supported_layer_schemes()
        # Only the first five entries of a scheme take part in the comparison:
        # (activation_dtype, weight_dtype, mode, algorithm, granularity).
        if not any(requested == tuple(scheme[:5]) for scheme in supported):
            # repr() of the raw values is used (not str(self)) so that building
            # the message cannot itself fail on attributes lacking `.name`.
            raise ValueError(
                'Unsupported LayerQuantRecipe configuration: '
                f'{requested!r}. See '
                'supported_schemes.get_supported_layer_schemes() for the '
                'supported combinations.'
            )
79
+
80
+
81
@dataclass
class GenerativeQuantRecipe:
    """Quantization recipe for a model composed of the Edge Generative API layers.

    Some layers (attention and feedforward) accept either one
    `LayerQuantRecipe` or a dictionary keyed by the TransformerBlock index,
    so that individual blocks can be given their own recipe. For example,

    ```
    default = LayerQuantRecipeA
    attention = { 2: LayerQuantRecipeB }
    feedforward = { 3: LayerQuantRecipeC }
    ```

    will apply LayerQuantRecipeA to the entire model, overridden by
    LayerQuantRecipeB for the TransformerBlock[2].attention layer and
    LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any
    config with invalid indices will be ignored.

    Attributes:
      default: The quantization recipe for global scope of the model.
      embedding: Recipe for the embedding table.
      attention: Recipe for the attention blocks. Either a single
        LayerQuantRecipe, or a dictionary of LayerQuantRecipe keyed by the
        TransformerBlock index.
      feedforward: Recipe for the feedforward layers. Either a single
        LayerQuantRecipe, or a dictionary of LayerQuantRecipe keyed by the
        TransformerBlock index.
    """

    default: Optional[LayerQuantRecipe] = None
    embedding: Optional[LayerQuantRecipe] = None
    attention: Union[
        Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
    ] = None
    feedforward: Union[
        Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
    ] = None

    def __str__(self):
        """Returns a multi-line listing with one scope per line."""
        # NOTE(review): the whitespace inside this literal was not recoverable
        # from the rendered diff; confirm it against the packaged file.
        return f"""GenerativeQuantRecipe(
  Default: {self.default}
  Embedding: {self.embedding}
  Attention: {self.attention}
  Feedforward: {self.feedforward}
)"""

    # repr() and str() share the same listing.
    __repr__ = __str__

    def verify(self):
        """Checks if the recipe configured can be supported in runtime.

        Calls .verify() on every configured recipe, in the order default,
        embedding, attention, feedforward. Scopes left as None are skipped.

        Raises:
          ValueError: If the recipe configured is invalid or unsupported.
        """
        # Scopes that hold a single recipe (or nothing).
        for single_recipe in (self.default, self.embedding):
            if single_recipe is not None:
                single_recipe.verify()

        # Scopes that hold either a single recipe or a per-block dictionary.
        for scoped in (self.attention, self.feedforward):
            if scoped is None:
                continue
            if isinstance(scoped, dict):
                for block_recipe in scoped.values():
                    block_recipe.verify()
            else:
                scoped.verify()
@@ -0,0 +1,51 @@
1
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Helper functions to construct custom quantization recipes.
17
+
18
+ These are intended for more advanced users who want to configure their own
19
+ quantization recipes. For pre-constructed recipes, use `quant_recipes.py` instead.
20
+
21
+ Typical usage example:
22
+
23
+ 1. Applying a single layer recipe to the entire model
24
+
25
+ quant_recipe.GenerativeQuantRecipe(
26
+ default=quant_recipe_utils.create_layer_quant_int8_dynamic()
27
+ )
28
+ """
29
+
30
+ from ai_edge_torch.generative.quantize import quant_attrs
31
+ from ai_edge_torch.generative.quantize import quant_recipe
32
+
33
+
34
def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
    """Builds a layer recipe: FP32 activations, INT8 weights, dynamic range.

    Returns:
      A new LayerQuantRecipe using the MIN_MAX algorithm with CHANNELWISE
      granularity.
    """
    settings = {
        'activation_dtype': quant_attrs.Dtype.FP32,
        'weight_dtype': quant_attrs.Dtype.INT8,
        'mode': quant_attrs.Mode.DYNAMIC_RANGE,
        'algorithm': quant_attrs.Algorithm.MIN_MAX,
        'granularity': quant_attrs.Granularity.CHANNELWISE,
    }
    return quant_recipe.LayerQuantRecipe(**settings)
42
+
43
+
44
def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
    """Builds a layer recipe: FP32 activations, FP16 weights, weight-only.

    Returns:
      A new LayerQuantRecipe using the FLOAT_CAST algorithm with granularity
      NONE.
    """
    settings = {
        'activation_dtype': quant_attrs.Dtype.FP32,
        'weight_dtype': quant_attrs.Dtype.FP16,
        'mode': quant_attrs.Mode.WEIGHT_ONLY,
        'algorithm': quant_attrs.Algorithm.FLOAT_CAST,
        'granularity': quant_attrs.Granularity.NONE,
    }
    return quant_recipe.LayerQuantRecipe(**settings)
@@ -0,0 +1,48 @@
1
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Helper functions to create common and supported quantization recipes.
17
+
18
+ These recipes will work with models created with the Edge Generative API only.
19
+ They assume a Transformer architecture congruent with
20
+ ai_edge_torch/generative/layers/model_config.py:ModelConfig.
21
+
22
+ Typical usage example:
23
+
24
+ quant_config = quant_recipes.full_int8_dynamic_recipe()
25
+ edge_model = ai_edge_torch.convert(
26
+ model, (tokens, input_pos), quant_config=quant_config
27
+ )
28
+ """
29
+
30
+ from ai_edge_torch.generative.quantize import quant_recipe
31
+ from ai_edge_torch.generative.quantize import quant_recipe_utils
32
+ from ai_edge_torch.quantize import quant_config
33
+
34
+
35
def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
    """Builds a QuantConfig whose model-wide default is the int8 dynamic recipe.

    Returns:
      A QuantConfig wrapping a GenerativeQuantRecipe in which only `default`
      is set, to quant_recipe_utils.create_layer_quant_int8_dynamic().
    """
    layer_recipe = quant_recipe_utils.create_layer_quant_int8_dynamic()
    model_recipe = quant_recipe.GenerativeQuantRecipe(default=layer_recipe)
    return quant_config.QuantConfig(generative_recipe=model_recipe)
41
+
42
+
43
def full_fp16_recipe() -> quant_config.QuantConfig:
    """Builds a QuantConfig whose model-wide default is the fp16 recipe.

    Returns:
      A QuantConfig wrapping a GenerativeQuantRecipe in which only `default`
      is set, to quant_recipe_utils.create_layer_quant_fp16().
    """
    layer_recipe = quant_recipe_utils.create_layer_quant_fp16()
    model_recipe = quant_recipe.GenerativeQuantRecipe(default=layer_recipe)
    return quant_config.QuantConfig(generative_recipe=model_recipe)
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
def get_supported_layer_schemes():
    """List of layer-scoped quantization schemes supported in runtime.

    Returns:
      A newly built list of
      tuple(activation_dtype, weight_dtype, mode, algorithm, granularity).
    """
    # The import stays inside the function, as in the original code.
    from ai_edge_torch.generative.quantize.quant_attrs import (
        Algorithm,
        Dtype,
        Granularity,
        Mode,
    )

    int8_dynamic_range = (
        Dtype.FP32,
        Dtype.INT8,
        Mode.DYNAMIC_RANGE,
        Algorithm.MIN_MAX,
        Granularity.CHANNELWISE,
    )
    int8_weight_only = (
        Dtype.FP32,
        Dtype.INT8,
        Mode.WEIGHT_ONLY,
        Algorithm.MIN_MAX,
        Granularity.CHANNELWISE,
    )
    fp16_weight_only = (
        Dtype.FP32,
        Dtype.FP16,
        Mode.WEIGHT_ONLY,
        Algorithm.FLOAT_CAST,
        Granularity.NONE,
    )
    return [int8_dynamic_range, int8_weight_only, fp16_weight_only]
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Testing weight loader utilities.
16
+
17
+ import os
18
+ import tempfile
19
+ import unittest
20
+
21
+ import safetensors.torch
22
+ import torch
23
+
24
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
25
+ from ai_edge_torch.generative.utilities import loader as loading_utils
26
+
27
+
28
class TestLoader(unittest.TestCase):
    """Unit tests that check weight loader."""

    def test_load_safetensors(self):
        """Tensors written with safetensors are listed by load_safetensors()."""
        with tempfile.TemporaryDirectory() as workdir:
            checkpoint_path = os.path.join(workdir, "test.safetensors")
            saved = {"weight": torch.randn(20, 10), "bias": torch.randn(20)}
            safetensors.torch.save_file(saved, checkpoint_path)

            loaded = loading_utils.load_safetensors(checkpoint_path)
            for tensor_name in ("weight", "bias"):
                self.assertIn(tensor_name, loaded)

    def test_load_statedict(self):
        """A torch.save()'d state dict is listed by load_pytorch_statedict()."""
        with tempfile.TemporaryDirectory() as workdir:
            checkpoint_path = os.path.join(workdir, "test.pt")
            linear = torch.nn.Linear(10, 5)
            torch.save(linear.state_dict(), checkpoint_path)

            loaded = loading_utils.load_pytorch_statedict(checkpoint_path)
            for tensor_name in ("weight", "bias"):
                self.assertIn(tensor_name, loaded)

    def test_model_loader(self):
        """ModelLoader strictly loads a one-layer TinyLlama checkpoint."""
        # Tensor name -> shape for the random checkpoint written below.
        weight_shapes = {
            "lm_head.weight": (32000, 2048),
            "model.embed_tokens.weight": (32000, 2048),
            "model.layers.0.input_layernorm.weight": (2048,),
            "model.layers.0.mlp.down_proj.weight": (2048, 5632),
            "model.layers.0.mlp.gate_proj.weight": (5632, 2048),
            "model.layers.0.mlp.up_proj.weight": (5632, 2048),
            "model.layers.0.post_attention_layernorm.weight": (2048,),
            "model.layers.0.self_attn.k_proj.weight": (256, 2048),
            "model.layers.0.self_attn.o_proj.weight": (2048, 2048),
            "model.layers.0.self_attn.q_proj.weight": (2048, 2048),
            "model.layers.0.self_attn.v_proj.weight": (256, 2048),
            "model.norm.weight": (2048,),
        }
        with tempfile.TemporaryDirectory() as workdir:
            checkpoint_path = os.path.join(workdir, "test.safetensors")
            random_weights = {
                tensor_name: torch.randn(shape)
                for tensor_name, shape in weight_shapes.items()
            }
            safetensors.torch.save_file(random_weights, checkpoint_path)

            # The checkpoint above only contains weights for layer 0.
            config = tiny_llama.get_model_config()
            config.num_layers = 1
            model = tiny_llama.TinyLLamma(config)

            loader = loading_utils.ModelLoader(
                checkpoint_path, tiny_llama.TENSOR_NAMES
            )
            # If this returns successfully, all the tensors were initialized.
            loader.load(model, strict=True)
77
+
78
+
79
+ if __name__ == "__main__":
80
+ unittest.main()
@@ -0,0 +1,235 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Testing model conversion for a few gen-ai models.
16
+ import copy
17
+ import os
18
+ import tempfile
19
+ import unittest
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ import ai_edge_torch
25
+ from ai_edge_torch.generative.examples.gemma import gemma
26
+ from ai_edge_torch.generative.examples.phi2 import phi2
27
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
28
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
29
+ from ai_edge_torch.testing import model_coverage
30
+
31
+
32
class TestModelConversion(unittest.TestCase):
    """Unit tests that check for model conversion and correctness."""

    # TODO(b/338288901): re-enable test to check output tensors.
    # While this is True the tests below only run the conversion; the
    # comparison against the PyTorch model is never executed.
    _SKIP_OUTPUT_CHECK = True

    def _assert_outputs_match(
        self, edge_model, torch_model, sample_args, **compare_kwargs
    ):
        """Asserts compare_tflite_torch() succeeds, unless checks are disabled.

        Args:
          edge_model: The converted model.
          torch_model: The PyTorch model to compare against.
          sample_args: Tuple of input tensors forwarded to the comparison.
          **compare_kwargs: Extra keyword arguments forwarded unchanged
            (e.g. atol, rtol, signature_name).
        """
        if self._SKIP_OUTPUT_CHECK:
            return
        self.assertTrue(
            model_coverage.compare_tflite_torch(
                edge_model,
                torch_model,
                sample_args,
                num_valid_inputs=1,
                **compare_kwargs,
            )
        )

    def _convert_and_check_toy_model(self, config, token_ids):
        """Builds ToyModelWithKV from `config`, converts it and checks outputs."""
        pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
        idx = torch.tensor(token_ids, dtype=torch.long)
        input_pos = torch.tensor([10], dtype=torch.int64)
        sample_args = (idx, input_pos)

        edge_model = ai_edge_torch.convert(pytorch_model, sample_args)

        self._assert_outputs_match(
            edge_model, pytorch_model, sample_args, atol=1e-5, rtol=1e-5
        )

    def _convert_and_check_prompt_model(self, pytorch_model, atol):
        """Converts `pytorch_model` on a zero-padded 4-token prompt of length 10."""
        prompt = torch.from_numpy(np.array([[1, 2, 3, 4]]))
        tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
        tokens[0, :4] = prompt
        input_pos = torch.arange(0, 10)
        sample_args = (tokens, input_pos)

        edge_model = ai_edge_torch.convert(pytorch_model, sample_args)

        self._assert_outputs_match(
            edge_model, pytorch_model, sample_args, atol=atol, rtol=1e-5
        )

    def test_toy_model_with_kv_cache(self):
        config = toy_model_with_kv_cache.get_model_config()
        self._convert_and_check_toy_model(config, [[1]])

    def test_toy_model_with_multi_batches(self):
        config = toy_model_with_kv_cache.get_model_config()
        config.batch_size = 2
        self._convert_and_check_toy_model(config, [[1], [2]])

    def test_toy_model_with_kv_cache_with_hlfb(self):
        config = toy_model_with_kv_cache.get_model_config()
        config.enable_hlfb = True
        self._convert_and_check_toy_model(config, [[1]])

    def test_tiny_llama(self):
        self.skipTest("b/338288901")
        config = tiny_llama.get_fake_model_config_for_test()
        self._convert_and_check_prompt_model(
            tiny_llama.TinyLLamma(config), atol=1e-5
        )

    def test_tiny_llama_multisig(self):
        config = tiny_llama.get_fake_model_config_for_test()
        pytorch_model = tiny_llama.TinyLLamma(config)

        # prefill
        seq_len = 10
        prefill_tokens = torch.full(
            (1, seq_len), 0, dtype=torch.long, device="cpu"
        )
        prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
        prefill_tokens[0, : len(prompt_token)] = prompt_token
        prefill_input_pos = torch.arange(0, seq_len)
        prefill_args = (prefill_tokens, prefill_input_pos)

        # decode
        decode_token = torch.tensor([[1]], dtype=torch.long)
        decode_input_pos = torch.tensor([5], dtype=torch.int64)
        decode_args = (decode_token, decode_input_pos)

        builder = ai_edge_torch.signature("prefill", pytorch_model, prefill_args)
        builder = builder.signature("decode", pytorch_model, decode_args)
        edge_model = builder.convert()

        if self._SKIP_OUTPUT_CHECK:
            return

        # Copy the model before the prefill comparison runs; the decode
        # comparison then uses the copy.
        copied_model = copy.deepcopy(pytorch_model)
        self._assert_outputs_match(
            edge_model, pytorch_model, prefill_args, signature_name="prefill"
        )
        self._assert_outputs_match(
            edge_model, copied_model, decode_args, signature_name="decode"
        )

    def test_gemma(self):
        self.skipTest("b/338288901")
        config = gemma.get_fake_model_config_2b_for_test()
        # TODO(talumbau, haoliang): debug numerical diff.
        self._convert_and_check_prompt_model(gemma.Gemma(config), atol=1e-2)

    def test_phi2(self):
        self.skipTest("b/338288901")
        config = phi2.get_fake_model_config_for_test()
        self._convert_and_check_prompt_model(phi2.Phi2(config), atol=1e-5)
232
+
233
+
234
+ if __name__ == "__main__":
235
+ unittest.main()