ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,10 @@
14
14
  # ==============================================================================
15
15
  # Common normalization layers.
16
16
 
17
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
17
18
  import torch
19
+ from torch import nn
20
+ import torch.nn.functional as F
18
21
 
19
22
 
20
23
  # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
58
61
  return output * (1 + self.weight)
59
62
  else:
60
63
  return output * self.weight
64
+
65
+
66
class GroupNorm(torch.nn.Module):
  """Group Normalization layer with an optional HLFB (composite) path.

  Holds a learnable per-channel `weight` and `bias` and applies
  `torch.nn.functional.group_norm`, either directly or wrapped by
  `group_norm_with_hlfb` when `enable_hlfb` is set.
  """

  def __init__(
      self,
      group_num: int,
      dim: int,
      eps: float = 1e-5,
      enable_hlfb: bool = False,
  ):
    """Initialize the GroupNorm layer.

    Args:
      group_num (int): Number of groups to separate the channels into.
      dim (int): Size of the learnable weight and bias vectors, i.e. the
        number of channels of the input tensor.
      eps (float): A small float value to ensure numerical stability (default:
        1e-5).
      enable_hlfb (bool): Whether to convert this normalization into a single
        op.
    """
    super().__init__()
    self.enable_hlfb = enable_hlfb
    self.group_num = group_num
    self.eps = eps
    self.weight = torch.nn.Parameter(torch.ones(dim))
    # NOTE(review): bias starts at ones rather than the conventional zeros;
    # presumably it is always overwritten by loaded checkpoint weights -
    # confirm before relying on an un-loaded layer.
    self.bias = torch.nn.Parameter(torch.ones(dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Running the forward pass of GroupNorm layer.

    Args:
      x (torch.Tensor): input tensor.

    Returns:
      torch.Tensor: output tensor after applying GroupNorm.
    """
    if self.enable_hlfb:
      return group_norm_with_hlfb(
          x,
          self.weight,
          self.bias,
          self.group_num,
          self.eps,
      )
    return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
111
+
112
+
113
class LayerNorm(torch.nn.Module):
  """Layer Normalization layer with an optional HLFB (composite) path.

  Holds a learnable `weight` and `bias` of size `dim`. When `enable_hlfb` is
  set the computation is delegated to `layer_norm_with_hlfb`; otherwise
  `torch.nn.functional.layer_norm` is applied over the trailing feature axis.
  """

  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
    """Initialize the LayerNorm layer.

    Args:
      dim (int): Size of the last (feature) dimension of the input tensor.
      eps (float): A small float value to ensure numerical stability (default:
        1e-5).
      enable_hlfb (bool): Whether to convert this normalization into a single
        op.
    """
    super().__init__()
    self.enable_hlfb = enable_hlfb
    self.eps = eps
    self.weight = torch.nn.Parameter(torch.ones(dim))
    # NOTE(review): bias starts at ones rather than the conventional zeros;
    # presumably it is always overwritten by loaded checkpoint weights -
    # confirm before relying on an un-loaded layer.
    self.bias = torch.nn.Parameter(torch.ones(dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Running the forward pass of LayerNorm layer.

    Args:
      x (torch.Tensor): input tensor whose last dimension has size `dim`.

    Returns:
      torch.Tensor: output tensor after applying LayerNorm.
    """
    if self.enable_hlfb:
      return layer_norm_with_hlfb(
          x,
          self.weight,
          self.bias,
          self.eps,
      )
    # Normalize over the feature axis only. Using `x.shape` as the normalized
    # shape would compute a single mean/variance across every axis (batch and
    # sequence included), which mixes statistics between independent rows.
    return F.layer_norm(
        x,
        self.weight.shape,
        self.weight,
        self.bias,
        self.eps,
    )
155
+
156
+
157
def group_norm_with_hlfb(
    x: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    num_groups: int,
    eps: float,
):
  """Apply Group Normalization inside a high-level function boundary.

  The tensors crossing the composite boundary are in channels-last layout:
  the BCHW input is permuted before being marked as an input, and the result
  is permuted again before being marked as an output. The actual
  normalization still runs on the BCHW layout in between.

  Args:
    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
    w (torch.Tensor): The weight tensor for the normalization.
    b (torch.Tensor): The bias tensor for the normalization.
    num_groups (int): Number of groups to separate the channels into.
    eps (float): A small float value to ensure numerical stability.

  Returns:
    The output tensor of Group Normalization, in BCHW layout.
  """
  to_channels_last = (0, 2, 3, 1)
  to_channels_first = (0, 3, 1, 2)

  # Switch to channels-last before entering the composite.
  boundary_in = torch.permute(x, to_channels_last)

  composite = StableHLOCompositeBuilder(
      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
  )
  boundary_in, w, b = composite.mark_inputs(boundary_in, w, b)

  # F.group_norm expects channels in axis 1, so restore BCHW for the math.
  normalized = F.group_norm(
      torch.permute(boundary_in, to_channels_first),
      num_groups,
      weight=w,
      bias=b,
      eps=eps,
  )
  boundary_out = composite.mark_outputs(
      torch.permute(normalized, to_channels_last)
  )

  # Hand callers back the layout they passed in.
  return torch.permute(boundary_out, to_channels_first)
189
+
190
+
191
def layer_norm_with_hlfb(
    x: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eps: float,
):
  """Layer Normalization with high-level function boundary enabled.

  The inputs are marked on a `StableHLOCompositeBuilder` named
  "odml.layer_norm" (carrying `eps` as an attribute), normalized with
  `torch.nn.functional.layer_norm`, and the result is marked as the
  composite output.

  Args:
    x (torch.Tensor): Input tensor for Layer Normalization. Its trailing
      dimensions must match the shape of `w`.
    w (torch.Tensor): The weight tensor for the normalization.
    b (torch.Tensor): The bias tensor for the normalization.
    eps (float): A small float value to ensure numerical stability.

  Returns:
    The output tensor of Layer Normalization.
  """
  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
  x, w, b = builder.mark_inputs(x, w, b)
  # Normalize over the trailing axes described by the weight only. Using
  # `x.shape` as the normalized shape would compute one mean/variance across
  # every axis (batch and sequence included) instead of per feature vector.
  y = F.layer_norm(
      x,
      w.shape,
      weight=w,
      bias=b,
      eps=eps,
  )
  y = builder.mark_outputs(y)
  return y
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
122
122
  config.attention_batch_size,
123
123
  config.dim,
124
124
  config.attention_config,
125
- 0,
126
125
  enable_hlfb=config.enable_hlfb,
127
126
  )
128
127
 
@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
180
179
  config.query_dim,
181
180
  config.cross_dim,
182
181
  config.attention_config,
183
- 0,
184
182
  enable_hlfb=config.enable_hlfb,
185
183
  )
186
184
 
@@ -12,19 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # A suite of tests to validate experimental external KV Cache layers and models.
16
15
 
17
- from ai_edge_torch.generative.examples.experimental.gemma import gemma
18
- from ai_edge_torch.generative.examples.experimental.phi import phi2
19
- from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama # NOQA
20
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
16
+ """A suite of tests to validate KV Cache layer."""
17
+
18
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
19
  import ai_edge_torch.generative.layers.model_config as cfg
22
20
  import torch
23
21
 
24
22
  from absl.testing import absltest as googletest
25
23
 
26
24
 
27
- class TestExternalKVLayers(googletest.TestCase):
25
+ class TestKVLayers(googletest.TestCase):
28
26
 
29
27
  def _get_test_config(
30
28
  self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -32,14 +30,16 @@ class TestExternalKVLayers(googletest.TestCase):
32
30
  attn_config = cfg.AttentionConfig(
33
31
  num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
34
32
  )
33
+ block_config = cfg.TransformerBlockConfig(
34
+ attn_config=attn_config, ff_config=None
35
+ )
35
36
  config = cfg.ModelConfig(
36
37
  kv_cache_max_len=kv_cache_max_len,
37
38
  embedding_dim=head_dim,
38
- attn_config=attn_config,
39
+ block_configs=block_config,
39
40
  num_layers=num_layers,
40
41
  max_seq_len=None,
41
42
  vocab_size=None,
42
- ff_config=None,
43
43
  )
44
44
  return config
45
45
 
@@ -54,7 +54,7 @@ class TestExternalKVLayers(googletest.TestCase):
54
54
  num_query_groups=NUM_QG,
55
55
  kv_cache_max_len=KV_LEN,
56
56
  )
57
- kv = kv_utils.EKVCache.from_model_config(config)
57
+ kv = kv_utils.KVCache.from_model_config(config)
58
58
  entry = kv.caches[0]
59
59
  # single-slice update
60
60
  input_pos = torch.tensor([1])
@@ -88,14 +88,14 @@ class TestExternalKVLayers(googletest.TestCase):
88
88
  def test_serialization(self):
89
89
  class TestModel(torch.nn.Module):
90
90
 
91
- def forward(self, kv: kv_utils.EKVCache) -> kv_utils.EKVCache:
91
+ def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
92
92
  updated_kv_entries = [
93
93
  kv_utils.KVCacheEntry(
94
94
  torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
95
95
  )
96
96
  for entry in kv.caches
97
97
  ]
98
- return kv_utils.EKVCache(updated_kv_entries)
98
+ return kv_utils.KVCache(updated_kv_entries)
99
99
 
100
100
  N = 1
101
101
  HEAD_DIM = 2
@@ -107,7 +107,7 @@ class TestExternalKVLayers(googletest.TestCase):
107
107
  num_query_groups=NUM_QG,
108
108
  kv_cache_max_len=KV_LEN,
109
109
  )
110
- kv = kv_utils.EKVCache.from_model_config(config)
110
+ kv = kv_utils.KVCache.from_model_config(config)
111
111
  model = TestModel()
112
112
  exported_program = torch.export.export(model, (kv,))
113
113
  input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +116,5 @@ class TestExternalKVLayers(googletest.TestCase):
116
116
  self.assertEqual(input_specs[1].arg.name, "kv_v_0")
117
117
 
118
118
 
119
- class TestExternalKVModels(googletest.TestCase):
120
-
121
- def test_can_build_gemma(self):
122
- gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
123
-
124
- def test_can_build_phi2(self):
125
- phi2.define_and_run(checkpoint_path=None, test_model=True)
126
-
127
- def test_can_build_tinyllama(self):
128
- tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
129
-
130
-
131
119
  if __name__ == "__main__":
132
120
  googletest.main()
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
71
71
  safetensors.torch.save_file(test_weights, file_path)
72
72
  cfg = tiny_llama.get_model_config()
73
73
  cfg.num_layers = 1
74
- model = tiny_llama.TinyLLamma(cfg)
74
+ model = tiny_llama.TinyLlama(cfg)
75
75
 
76
76
  loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
77
77
  # If this returns successfully, it means all the tensors were initialized.
@@ -12,16 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Testing model conversion for a few gen-ai models.
16
- import copy
15
+
16
+ """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch import config as ai_edge_config
20
- from ai_edge_torch.generative.examples.gemma import gemma, gemma2
21
- from ai_edge_torch.generative.examples.phi2 import phi2
22
- from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
20
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
23
21
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
- from ai_edge_torch.testing import model_coverage
22
+ from ai_edge_torch.generative.layers import kv_cache
23
+ from ai_edge_torch.generative.test import utils as test_utils
25
24
  import numpy as np
26
25
  import torch
27
26
 
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
49
48
  )
50
49
  def test_toy_model_with_kv_cache(self):
51
50
  config = toy_model_with_kv_cache.get_model_config()
52
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
53
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
51
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
52
+ tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
54
53
  [10], dtype=torch.int64
55
54
  )
56
-
57
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
55
+ kv = kv_cache.KVCache.from_model_config(config)
56
+
57
+ edge_model = ai_edge_torch.convert(
58
+ pytorch_model,
59
+ sample_kwargs={
60
+ "tokens": tokens,
61
+ "input_pos": input_pos,
62
+ "kv_cache": kv,
63
+ },
64
+ )
58
65
  edge_model.set_interpreter_builder(
59
66
  self._interpreter_builder(edge_model.tflite_model())
60
67
  )
61
68
 
62
69
  self.assertTrue(
63
- model_coverage.compare_tflite_torch(
70
+ test_utils.compare_tflite_torch(
64
71
  edge_model,
65
72
  pytorch_model,
66
- (idx, input_pos),
67
- num_valid_inputs=1,
73
+ tokens,
74
+ input_pos,
75
+ kv,
76
+ signature_name="serving_default",
68
77
  atol=1e-5,
69
78
  rtol=1e-5,
70
79
  )
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
77
86
  def test_toy_model_with_kv_cache_with_hlfb(self):
78
87
  config = toy_model_with_kv_cache.get_model_config()
79
88
  config.enable_hlfb = True
80
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
81
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
89
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
90
+ tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
82
91
  [10], dtype=torch.int64
83
92
  )
84
-
85
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
93
+ kv = kv_cache.KVCache.from_model_config(config)
94
+
95
+ edge_model = ai_edge_torch.convert(
96
+ pytorch_model,
97
+ sample_kwargs={
98
+ "tokens": tokens,
99
+ "input_pos": input_pos,
100
+ "kv_cache": kv,
101
+ },
102
+ )
86
103
  edge_model.set_interpreter_builder(
87
104
  self._interpreter_builder(edge_model.tflite_model())
88
105
  )
89
106
 
90
107
  self.assertTrue(
91
- model_coverage.compare_tflite_torch(
108
+ test_utils.compare_tflite_torch(
92
109
  edge_model,
93
110
  pytorch_model,
94
- (idx, input_pos),
95
- num_valid_inputs=1,
111
+ tokens,
112
+ input_pos,
113
+ kv,
114
+ signature_name="serving_default",
96
115
  atol=1e-5,
97
116
  rtol=1e-5,
98
117
  )
@@ -104,7 +123,7 @@ class TestModelConversion(googletest.TestCase):
104
123
  )
105
124
  def test_tiny_llama_multisig(self):
106
125
  config = tiny_llama.get_fake_model_config()
107
- pytorch_model = tiny_llama.TinyLLamma(config).eval()
126
+ pytorch_model = tiny_llama.TinyLlama(config).eval()
108
127
 
109
128
  # prefill
110
129
  seq_len = 10
@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
117
136
  decode_token = torch.tensor([[1]], dtype=torch.long)
118
137
  decode_input_pos = torch.tensor([5], dtype=torch.int64)
119
138
 
139
+ kv = kv_cache.KVCache.from_model_config(config)
140
+
120
141
  edge_model = (
121
142
  ai_edge_torch.signature(
122
- "prefill", pytorch_model, (prefill_tokens, prefill_input_pos)
143
+ "prefill",
144
+ pytorch_model,
145
+ sample_kwargs={
146
+ "tokens": prefill_tokens,
147
+ "input_pos": prefill_input_pos,
148
+ "kv_cache": kv,
149
+ },
150
+ )
151
+ .signature(
152
+ "decode",
153
+ pytorch_model,
154
+ sample_kwargs={
155
+ "tokens": decode_token,
156
+ "input_pos": decode_input_pos,
157
+ "kv_cache": kv,
158
+ },
123
159
  )
124
- .signature("decode", pytorch_model, (decode_token, decode_input_pos))
125
160
  .convert()
126
161
  )
127
162
  edge_model.set_interpreter_builder(
128
163
  self._interpreter_builder(edge_model.tflite_model())
129
164
  )
130
165
 
131
- copied_model = copy.deepcopy(pytorch_model)
132
- copied_edge = copy.deepcopy(edge_model)
133
-
134
166
  self.assertTrue(
135
- model_coverage.compare_tflite_torch(
167
+ test_utils.compare_tflite_torch(
136
168
  edge_model,
137
169
  pytorch_model,
138
- (prefill_tokens, prefill_input_pos),
170
+ prefill_tokens,
171
+ prefill_input_pos,
172
+ kv,
139
173
  signature_name="prefill",
140
- num_valid_inputs=1,
174
+ atol=1e-5,
175
+ rtol=1e-5,
141
176
  )
142
177
  )
143
178
 
144
179
  self.assertTrue(
145
- model_coverage.compare_tflite_torch(
146
- copied_edge,
147
- copied_model,
148
- (decode_token, decode_input_pos),
180
+ test_utils.compare_tflite_torch(
181
+ edge_model,
182
+ pytorch_model,
183
+ decode_token,
184
+ decode_input_pos,
185
+ kv,
149
186
  signature_name="decode",
150
- num_valid_inputs=1,
187
+ atol=1e-5,
188
+ rtol=1e-5,
151
189
  )
152
190
  )
153
191
 
@@ -12,16 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Testing model conversion for a few gen-ai models.
16
- import copy
15
+
16
+ """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch import config as ai_edge_config
20
- from ai_edge_torch.generative.examples.gemma import gemma, gemma2
21
- from ai_edge_torch.generative.examples.phi2 import phi2
22
- from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
23
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
- from ai_edge_torch.testing import model_coverage
20
+ from ai_edge_torch.generative.examples.gemma import gemma
21
+ from ai_edge_torch.generative.examples.gemma import gemma2
22
+ from ai_edge_torch.generative.examples.phi import phi2
23
+ from ai_edge_torch.generative.layers import kv_cache
24
+ from ai_edge_torch.generative.test import utils as test_utils
25
25
  import numpy as np
26
26
  import torch
27
27
 
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
55
55
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
56
56
  tokens[0, :4] = idx
57
57
  input_pos = torch.arange(0, 10)
58
-
59
- edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
58
+ kv = kv_cache.KVCache.from_model_config(config)
59
+
60
+ edge_model = ai_edge_torch.convert(
61
+ model,
62
+ sample_kwargs={
63
+ "tokens": tokens,
64
+ "input_pos": input_pos,
65
+ "kv_cache": kv,
66
+ },
67
+ )
60
68
  edge_model.set_interpreter_builder(
61
69
  self._interpreter_builder(edge_model.tflite_model())
62
70
  )
63
71
 
64
72
  self.assertTrue(
65
- model_coverage.compare_tflite_torch(
73
+ test_utils.compare_tflite_torch(
66
74
  edge_model,
67
75
  model,
68
- (tokens, input_pos),
69
- num_valid_inputs=1,
76
+ tokens,
77
+ input_pos,
78
+ kv,
79
+ signature_name="serving_default",
70
80
  atol=1e-2,
71
81
  rtol=1e-5,
72
82
  )
@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
85
95
  prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
86
96
  prefill_tokens[0, :4] = idx
87
97
  prefill_input_pos = torch.arange(0, 10)
98
+ kv = kv_cache.KVCache.from_model_config(config)
88
99
 
89
100
  edge_model = ai_edge_torch.signature(
90
- "prefill", model, (prefill_tokens, prefill_input_pos)
101
+ "prefill",
102
+ model,
103
+ sample_kwargs={
104
+ "tokens": prefill_tokens,
105
+ "input_pos": prefill_input_pos,
106
+ "kv_cache": kv,
107
+ },
91
108
  ).convert()
92
109
  edge_model.set_interpreter_builder(
93
110
  self._interpreter_builder(edge_model.tflite_model())
94
111
  )
95
112
 
96
113
  self.assertTrue(
97
- model_coverage.compare_tflite_torch(
114
+ test_utils.compare_tflite_torch(
98
115
  edge_model,
99
116
  model,
100
- (prefill_tokens, prefill_input_pos),
117
+ prefill_tokens,
118
+ prefill_input_pos,
119
+ kv,
101
120
  signature_name="prefill",
102
- num_valid_inputs=1,
103
- atol=1e-2,
104
- rtol=1e-5,
121
+ atol=1e-1,
122
+ rtol=1e-3,
105
123
  )
106
124
  )
107
125
 
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
117
135
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
118
136
  tokens[0, :4] = idx
119
137
  input_pos = torch.arange(0, 10)
120
-
121
- edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
138
+ kv = kv_cache.KVCache.from_model_config(config)
139
+
140
+ edge_model = ai_edge_torch.convert(
141
+ pytorch_model,
142
+ sample_kwargs={
143
+ "tokens": tokens,
144
+ "input_pos": input_pos,
145
+ "kv_cache": kv,
146
+ },
147
+ )
122
148
  edge_model.set_interpreter_builder(
123
149
  self._interpreter_builder(edge_model.tflite_model())
124
150
  )
125
151
 
126
152
  self.assertTrue(
127
- model_coverage.compare_tflite_torch(
153
+ test_utils.compare_tflite_torch(
128
154
  edge_model,
129
155
  pytorch_model,
130
- (tokens, input_pos),
131
- num_valid_inputs=1,
156
+ tokens,
157
+ input_pos,
158
+ kv,
159
+ signature_name="serving_default",
132
160
  atol=1e-3,
133
161
  rtol=1e-3,
134
162
  )
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Common utils for testing."""
17
+
18
+ from ai_edge_torch import model
19
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
20
+ from ai_edge_torch.lowertools import common_utils
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils import _pytree as pytree
24
+
25
+
26
def compare_tflite_torch(
    edge_model: model.Model,
    torch_model: torch.nn.Module,
    tokens: torch.Tensor,
    input_pos: torch.Tensor,
    kv_cache: kv_utils.KVCache,
    signature_name: str,
    atol: float = 1e-5,
    rtol: float = 1e-5,
):
  """Compares torch models and TFLite models.

  Runs `torch_model` positionally with (tokens, input_pos, kv_cache) and
  `edge_model` with the same data passed as numpy keyword arguments, then
  checks that the "logits" entries of both outputs agree.

  Args:
    edge_model: Converted model, invoked with `signature_name`.
    torch_model: Reference PyTorch module.
    tokens: Token tensor fed to both models.
    input_pos: Input position tensor fed to both models.
    kv_cache: KV cache fed to both models; for the edge model it is flattened
      into one keyword argument per leaf tensor.
    signature_name: Name of the edge model signature to run.
    atol: Absolute tolerance forwarded to `np.allclose`.
    rtol: Relative tolerance forwarded to `np.allclose`.

  Returns:
    The result of `np.allclose` on the two "logits" outputs.
  """
  # Flatten the cache under the "kv_cache" key so that the generated leaf
  # names line up with the keyword arguments of the edge model signature.
  kv_leaves, kv_spec = pytree.tree_flatten({"kv_cache": kv_cache})
  kv_leaf_names = common_utils.flat_dict_names(
      kv_spec.children_specs, kv_spec.context
  )

  reference = torch_model(tokens, input_pos, kv_cache)

  edge_kv_kwargs = {}
  for leaf_name, leaf in zip(kv_leaf_names, kv_leaves):
    edge_kv_kwargs[leaf_name] = leaf.numpy()

  candidate = edge_model(
      signature_name=signature_name,
      tokens=tokens.numpy(),
      input_pos=input_pos.numpy(),
      **edge_kv_kwargs,
  )

  expected_logits = reference["logits"].detach().numpy()
  return np.allclose(
      candidate["logits"],
      expected_logits,
      atol=atol,
      rtol=rtol,
  )