ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (68)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
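
The headline change in this release is that the generative examples are converted with keyword sample inputs and an explicit, exported KV cache instead of positional tensors. A minimal sketch of the new flow, assuming the toy example model and the argument names exercised by the test diffs below:

    import ai_edge_torch
    import torch

    from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
    from ai_edge_torch.generative.layers import kv_cache

    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()

    tokens = torch.tensor([[1]], dtype=torch.int)
    input_pos = torch.tensor([10], dtype=torch.int)
    kv = kv_cache.KVCache.from_model_config(config)

    # The KV cache is a pytree input; conversion flattens it into named
    # signature tensors ("kv_k_0", "kv_v_0", ...).
    edge_model = ai_edge_torch.convert(
        pytorch_model,
        sample_kwargs={"tokens": tokens, "input_pos": input_pos, "kv_cache": kv},
    )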
ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py}
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A suite of tests to validate experimental external KV Cache layers and models.
 
-from ai_edge_torch.generative.examples.experimental.gemma import gemma
-from ai_edge_torch.generative.examples.experimental.phi import phi2
-from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+"""A suite of tests to validate KV Cache layer."""
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
 from absl.testing import absltest as googletest
 
 
-class TestExternalKVLayers(googletest.TestCase):
+class TestKVLayers(googletest.TestCase):
 
   def _get_test_config(
       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -32,14 +30,16 @@ class TestExternalKVLayers(googletest.TestCase):
     attn_config = cfg.AttentionConfig(
         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
     )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config, ff_config=None
+    )
     config = cfg.ModelConfig(
         kv_cache_max_len=kv_cache_max_len,
         embedding_dim=head_dim,
-        attn_config=attn_config,
+        block_configs=block_config,
         num_layers=num_layers,
         max_seq_len=None,
         vocab_size=None,
-        ff_config=None,
     )
     return config
 
@@ -54,7 +54,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.EKVCache.from_model_config(config)
+    kv = kv_utils.KVCache.from_model_config(config)
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])
@@ -88,14 +88,14 @@ class TestExternalKVLayers(googletest.TestCase):
   def test_serialization(self):
     class TestModel(torch.nn.Module):
 
-      def forward(self, kv: kv_utils.EKVCache) -> kv_utils.EKVCache:
+      def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
         updated_kv_entries = [
             kv_utils.KVCacheEntry(
                 torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
            )
            for entry in kv.caches
        ]
-        return kv_utils.EKVCache(updated_kv_entries)
+        return kv_utils.KVCache(updated_kv_entries)
 
     N = 1
     HEAD_DIM = 2
@@ -107,7 +107,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.EKVCache.from_model_config(config)
+    kv = kv_utils.KVCache.from_model_config(config)
     model = TestModel()
     exported_program = torch.export.export(model, (kv,))
     input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +116,5 @@ class TestExternalKVLayers(googletest.TestCase):
     self.assertEqual(input_specs[1].arg.name, "kv_v_0")
 
 
-class TestExternalKVModels(googletest.TestCase):
-
-  def test_can_build_gemma(self):
-    gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
-
-  def test_can_build_phi2(self):
-    phi2.define_and_run(checkpoint_path=None, test_model=True)
-
-  def test_can_build_tinyllama(self):
-    tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
-
-
 if __name__ == "__main__":
   googletest.main()
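
The test above also reflects the reworked config schema: per-layer attention and feed-forward settings now hang off a TransformerBlockConfig, which ModelConfig receives through block_configs (replacing the old top-level attn_config/ff_config fields). A minimal sketch with the same field names as the hunk above; the concrete sizes are placeholders:

    import ai_edge_torch.generative.layers.model_config as cfg
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    attn_config = cfg.AttentionConfig(
        num_heads=1, head_dim=16, num_query_groups=1
    )
    block_config = cfg.TransformerBlockConfig(
        attn_config=attn_config, ff_config=None
    )
    config = cfg.ModelConfig(
        kv_cache_max_len=16,
        embedding_dim=16,
        block_configs=block_config,
        num_layers=2,
        max_seq_len=None,
        vocab_size=None,
    )
    # Builds one KVCacheEntry per layer from the config.
    kv = kv_utils.KVCache.from_model_config(config)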
ai_edge_torch/generative/test/test_loader.py
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.TinyLLamma(cfg)
+    model = tiny_llama.TinyLlama(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.
ai_edge_torch/generative/test/test_model_conversion.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Testing model conversion for a few gen-ai models.
-import copy
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma, gemma2
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.testing import model_coverage
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -43,28 +42,32 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
+  def _test_model_with_kv_cache(self, config, pytorch_model):
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
+        [10], dtype=torch.int
+    )
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
     )
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
             atol=1e-5,
             rtol=1e-5,
         )
@@ -74,83 +77,95 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_toy_model_with_kv_cache_with_hlfb(self):
+  def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
-    )
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_tiny_llama_multisig(self):
-    config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLLamma(config).eval()
+  def test_toy_model_with_kv_cache_with_hlfb(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = True
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)
 
+  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
-    prefill_input_pos = torch.arange(0, seq_len)
+    prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
 
     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.long)
-    decode_input_pos = torch.tensor([5], dtype=torch.int64)
+    decode_token = torch.tensor([[1]], dtype=torch.int)
+    decode_input_pos = torch.tensor([5], dtype=torch.int)
+
+    kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = (
         ai_edge_torch.signature(
-            "prefill", pytorch_model, (prefill_tokens, prefill_input_pos)
+            "prefill",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": prefill_tokens,
+                "input_pos": prefill_input_pos,
+                "kv_cache": kv,
+            },
+        )
+        .signature(
+            "decode",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": decode_token,
+                "input_pos": decode_input_pos,
+                "kv_cache": kv,
+            },
         )
-        .signature("decode", pytorch_model, (decode_token, decode_input_pos))
         .convert()
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
-    copied_model = copy.deepcopy(pytorch_model)
-    copied_edge = copy.deepcopy(edge_model)
-
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-            (prefill_tokens, prefill_input_pos),
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
            signature_name="prefill",
-            num_valid_inputs=1,
+            atol=atol,
+            rtol=atol,
        )
    )
 
    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            copied_edge,
-            copied_model,
-            (decode_token, decode_input_pos),
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            decode_token,
+            decode_input_pos,
+            kv,
            signature_name="decode",
-            num_valid_inputs=1,
+            atol=atol,
+            rtol=atol,
        )
    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_tiny_llama_multisig(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
+
 
 if __name__ == "__main__":
   googletest.main()
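
The multisignature refactor above is the piece most likely to affect downstream users: prefill and decode are exported as two signatures over the same weights and a shared KV-cache structure. A condensed sketch, assuming the tiny_llama example module and the shapes used in the test:

    import ai_edge_torch
    import torch

    from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
    from ai_edge_torch.generative.layers import kv_cache

    config = tiny_llama.get_fake_model_config()
    model = tiny_llama.TinyLlama(config).eval()
    kv = kv_cache.KVCache.from_model_config(config)

    prefill_tokens = torch.zeros((1, 10), dtype=torch.int)
    prefill_input_pos = torch.arange(0, 10, dtype=torch.int)
    decode_token = torch.tensor([[1]], dtype=torch.int)
    decode_input_pos = torch.tensor([5], dtype=torch.int)

    # Both signatures are traced from the same module; the resulting
    # TFLite model exposes "prefill" and "decode" entry points.
    edge_model = (
        ai_edge_torch.signature(
            "prefill",
            model,
            sample_kwargs={
                "tokens": prefill_tokens,
                "input_pos": prefill_input_pos,
                "kv_cache": kv,
            },
        )
        .signature(
            "decode",
            model,
            sample_kwargs={
                "tokens": decode_token,
                "input_pos": decode_input_pos,
                "kv_cache": kv,
            },
        )
        .convert()
    )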
ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -12,16 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Testing model conversion for a few gen-ai models.
-import copy
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma, gemma2
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.testing import model_coverage
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -43,32 +45,36 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    model = gemma.Gemma(config)
-
+  def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    input_pos = torch.arange(0, 10, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config)
 
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    edge_model = ai_edge_torch.signature(
+        signature_name,
+        model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-2,
-            rtol=1e-5,
+            tokens,
+            input_pos,
+            kv,
+            signature_name=signature_name,
+            atol=atol,
+            rtol=rtol,
        )
    )
 
@@ -76,34 +82,21 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_gemma2(self):
-    config = gemma2.get_fake_model_config()
-    model = gemma2.Gemma2(config)
-    model.eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    prefill_tokens[0, :4] = idx
-    prefill_input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.signature(
-        "prefill", model, (prefill_tokens, prefill_input_pos)
-    ).convert()
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    pytorch_model = gemma.Gemma(config).eval()
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
    )
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            model,
-            (prefill_tokens, prefill_input_pos),
-            signature_name="prefill",
-            num_valid_inputs=1,
-            atol=1e-2,
-            rtol=1e-5,
-        )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    pytorch_model = gemma2.Gemma2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -112,27 +105,27 @@ class TestModelConversion(googletest.TestCase):
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-3,
-            rtol=1e-3,
-        )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_smollm(self):
+    config = smollm.get_fake_model_config()
+    pytorch_model = smollm.SmolLM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_openelm(self):
+    config = openelm.get_fake_model_config()
+    pytorch_model = openelm.OpenELM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
 
 if __name__ == "__main__":
ai_edge_torch/generative/test/test_quantize.py
@@ -115,8 +115,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)
 
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
@@ -131,8 +131,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_weight_sharing(self):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)
 
     quant_config = quant_recipes.full_int8_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
@@ -149,7 +149,7 @@ class TestQuantizeConvert(parameterized.TestCase):
       self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
 
ai_edge_torch/generative/test/utils.py (new file)
@@ -0,0 +1,54 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utils for testing."""
+
+from ai_edge_torch import model
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.lowertools import common_utils
+import numpy as np
+import torch
+from torch.utils import _pytree as pytree
+
+
+def compare_tflite_torch(
+    edge_model: model.Model,
+    torch_model: torch.nn.Module,
+    tokens: torch.Tensor,
+    input_pos: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+    signature_name: str,
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+):
+  """Compares torch models and TFLite models."""
+  values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
+  flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
+  torch_output = torch_model(tokens, input_pos, kv_cache)
+
+  input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
+  edge_output = edge_model(
+      signature_name=signature_name,
+      tokens=tokens.numpy(),
+      input_pos=input_pos.numpy(),
+      **input_kv_flatten,
+  )
+
+  return np.allclose(
+      edge_output["logits"],
+      torch_output["logits"].detach().numpy(),
+      atol=atol,
+      rtol=rtol,
+  )
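
A usage sketch for the new helper, mirroring how the conversion tests above call it: the KV cache is flattened with torch's pytree utilities so the TFLite input names line up with the exported signature, and only the "logits" output is compared. The variables below are assumed to come from a prior conversion, as in the sketches earlier in this diff:

    from ai_edge_torch.generative.test import utils as test_utils

    ok = test_utils.compare_tflite_torch(
        edge_model,     # the converted ai_edge_torch model
        pytorch_model,  # the source torch.nn.Module
        tokens,         # (1, seq_len) int token tensor
        input_pos,      # int position tensor
        kv,             # kv_cache.KVCache.from_model_config(config)
        signature_name="serving_default",
        atol=1e-5,
        rtol=1e-5,
    )
    assert ok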