ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl

Files changed (27)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -3
  12. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +40 -32
  13. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  14. ai_edge_torch/generative/layers/model_config.py +6 -0
  15. ai_edge_torch/generative/test/test_loader.py +2 -1
  16. ai_edge_torch/generative/test/test_model_conversion.py +39 -17
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  18. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  19. ai_edge_torch/lowertools/translate_recipe.py +2 -2
  20. ai_edge_torch/version.py +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD +25 -26
  23. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  24. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  25. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/LICENSE +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/WHEEL +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/qwen.py
@@ -15,28 +15,10 @@
 
 """Example of building Qwen 2.5 models."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# Qwen re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class Qwen(tiny_llama.TinyLlama):
-  """A Qwen model built from the Edge Generative API layers.
-
-  Qwen 2.5 shares the same architecture as TinyLlama.
-  """
+from ai_edge_torch.generative.utilities import model_builder
 
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # Qwen re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
-  model = Qwen(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
-
-
-def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+def build_3b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_3b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
+def build_1_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
+def build_0_5b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
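The refactored Qwen builders above are thin wrappers around the shared model_builder.build_decoder_only_model helper. A minimal usage sketch (the checkpoint path below is a placeholder, not part of this diff):

```python
from ai_edge_torch.generative.examples.qwen import qwen

# Hypothetical local checkpoint path; any Qwen 2.5 0.5B checkpoint laid out
# with the expected tensor names should work the same way.
model = qwen.build_0_5b_model("/tmp/qwen2.5-0.5b", kv_cache_max_len=1024)
# The returned object is a model_builder.DecoderOnlyModel ready for conversion.
```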
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -15,29 +15,10 @@
 
 """Example of building a SmolLM model."""
 
-import copy
-
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-from torch import nn
-
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# SmolLM re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
-
-
-class SmolLM(tiny_llama.TinyLlama):
-  """A SmolLM model built from the Edge Generative API layers.
+from ai_edge_torch.generative.utilities import model_builder
 
-  SmolLM shares the same architecture as TinyLlama, but with different model
-  sizes.
-  """
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    # SmolLM re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = SmolLM(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -75,9 +75,7 @@ class CLIP(nn.Module):
     )
 
   @torch.inference_mode
-  def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.int)
-
+  def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
       state = layer(state, mask=self.mask_cache)
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -13,47 +13,54 @@
 # limitations under the License.
 # ==============================================================================
 
-import argparse
 import os
-from pathlib import Path
-from typing import Optional
+import pathlib
 
+from absl import app
+from absl import flags
 import ai_edge_torch
-import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
-import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
-import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
-from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
-import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.examples.stable_diffusion import clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion
+from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.quantize import quant_recipes
-import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch
 
-arg_parser = argparse.ArgumentParser()
-arg_parser.add_argument(
-    '--clip_ckpt',
-    type=str,
+_CLIP_CKPT = flags.DEFINE_string(
+    'clip_ckpt',
+    None,
     help='Path to source CLIP model checkpoint',
    required=True,
 )
-arg_parser.add_argument(
-    '--diffusion_ckpt',
-    type=str,
+
+_DIFFUSION_CKPT = flags.DEFINE_string(
+    'diffusion_ckpt',
+    None,
    help='Path to source diffusion model checkpoint',
    required=True,
 )
-arg_parser.add_argument(
-    '--decoder_ckpt',
-    type=str,
+
+_DECODER_CKPT = flags.DEFINE_string(
+    'decoder_ckpt',
+    None,
    help='Path to source image decoder model checkpoint',
    required=True,
 )
-arg_parser.add_argument(
-    '--output_dir',
-    type=str,
+
+_OUTPUT_DIR = flags.DEFINE_string(
+    'output_dir',
+    None,
    help='Path to the converted TF Lite directory.',
    required=True,
 )
 
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    help='Whether to quantize the model during conversion.',
+    default=True,
+)
+
 
 @torch.inference_mode
 def convert_stable_diffusion_to_tflite(
@@ -111,7 +118,7 @@ def convert_stable_diffusion_to_tflite(
   time_embedding = util.get_time_embedding(timestamp)
 
   if not os.path.exists(output_dir):
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
 
   quant_config = (
       quant_recipes.full_int8_weight_only_recipe() if quantize else None
@@ -142,14 +149,15 @@ def convert_stable_diffusion_to_tflite(
   ).export(f'{output_dir}/decoder.tflite')
 
 
-if __name__ == '__main__':
-  args = arg_parser.parse_args()
+def main(_):
   convert_stable_diffusion_to_tflite(
-      output_dir=args.output_dir,
-      clip_ckpt_path=args.clip_ckpt,
-      diffusion_ckpt_path=args.diffusion_ckpt,
-      decoder_ckpt_path=args.decoder_ckpt,
-      image_height=512,
-      image_width=512,
-      quantize=True,
+      output_dir=_OUTPUT_DIR.value,
+      clip_ckpt_path=_CLIP_CKPT.value,
+      diffusion_ckpt_path=_DIFFUSION_CKPT.value,
+      decoder_ckpt_path=_DECODER_CKPT.value,
+      quantize=_QUANTIZE.value,
   )
+
+
+if __name__ == '__main__':
+  app.run(main)
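With argparse replaced by absl flags, the converter is now driven by app.run(main). A sketch of an equivalent programmatic invocation, assuming placeholder checkpoint and output paths:

```python
from absl import app

from ai_edge_torch.generative.examples.stable_diffusion import convert_to_tflite

# All paths below are placeholders for illustration only.
app.run(
    convert_to_tflite.main,
    argv=[
        "convert_to_tflite",  # argv[0] is the program name
        "--clip_ckpt=/tmp/sd/clip.pt",
        "--diffusion_ckpt=/tmp/sd/diffusion.pt",
        "--decoder_ckpt=/tmp/sd/decoder.pt",
        "--output_dir=/tmp/sd_tflite",
        "--noquantize",  # absl boolean flags accept the "no" prefix to disable
    ],
)
```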
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,102 +15,10 @@
 
 """Example of building a TinyLlama model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
+from ai_edge_torch.generative.utilities import model_builder
 
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head="lm_head",
-)
-
-
-class TinyLlama(nn.Module):
-  """A TinyLlama model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # TinyLlama has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
+TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
   )
   return config
@@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = TinyLlama(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  loader.load(model)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
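The TinyLlama-specific nn.Module is gone; the example now builds a generic model_builder.DecoderOnlyModel. A sketch of a single-token forward pass with the fake config, assuming the new class keeps the removed forward(tokens, input_pos, kv_cache) contract and its {"logits", "kv_cache"} return dict:

```python
import torch

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities import model_builder

config = tiny_llama.get_fake_model_config()
model = model_builder.DecoderOnlyModel(config).eval()

tokens = torch.tensor([[1]], dtype=torch.int)    # one prompt token
input_pos = torch.tensor([0], dtype=torch.int)   # its position in the sequence
kv = kv_cache.KVCache.from_model_config(config)  # external KV cache

# Assumed to mirror the removed TinyLlama.forward signature shown above.
output = model(tokens=tokens, input_pos=input_pos, kv_cache=kv)
logits, updated_kv = output["logits"], output["kv_cache"]
```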
ai_edge_torch/generative/layers/model_config.py
@@ -184,8 +184,14 @@ class ModelConfig:
       default_factory=NormalizationConfig
   )
 
+  # Scale factor of the embedding.
+  embedding_scale: Optional[float] = None
+
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
+  # Whether LLM's HEAD shares the weight of the embedding.
+  lm_head_share_weight_with_embedding: bool = True
+
   # Whether to turn on high-level function boundary.
   enable_hlfb: bool = False
 
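The two new fields control whether the LM head reuses the embedding weights and whether the embedding output is scaled. A small sketch, using the TinyLlama example config which opts out of weight sharing in this release:

```python
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama

config = tiny_llama.get_model_config()
# TinyLlama keeps a separate lm_head tensor, so sharing is disabled here.
print(config.lm_head_share_weight_with_embedding)  # expected: False
# embedding_scale defaults to None, i.e. the embedding output is not scaled.
print(config.embedding_scale)
```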
ai_edge_torch/generative/test/test_loader.py
@@ -19,6 +19,7 @@ import tempfile
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import loader as loading_utils
+from ai_edge_torch.generative.utilities import model_builder
 import safetensors.torch
 import torch
 
@@ -71,7 +72,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.TinyLlama(cfg)
+    model = model_builder.DecoderOnlyModel(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.
ai_edge_torch/generative/test/test_model_conversion.py
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
+from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -42,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  def _test_model_with_kv_cache(self, config, pytorch_model):
+  def _get_params(self, enable_hlfb: bool):
+    """Returns a model, edge model and the kwargs to use for testing."""
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = enable_hlfb
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)
+    kwargs = {
+        "tokens": tokens,
+        "input_pos": input_pos,
+        "kv_cache": kv,
+    }
 
     edge_model = ai_edge_torch.convert(
         pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
+        sample_kwargs=kwargs,
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
+    return pytorch_model, edge_model, kwargs
+
+  def _test_model_with_kv_cache(self, enable_hlfb: bool):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
 
     self.assertTrue(
         test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-            tokens,
-            input_pos,
-            kv,
+            kwargs["tokens"],
+            kwargs["input_pos"],
+            kwargs["kv_cache"],
            signature_name="serving_default",
            atol=1e-5,
            rtol=1e-5,
@@ -78,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
      reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=False)
 
   @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=True)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_toy_model_has_ekv_op(self):
+    """Tests that the model has the external kv cache op."""
+    _, edge_model, _ = self._get_params(enable_hlfb=True)
+    interpreter_ = interpreter.InterpreterWithCustomOps(
+        custom_op_registerers=["GenAIOpsRegisterer"],
+        model_content=edge_model.tflite_model(),
+        experimental_default_delegate_latest_features=True,
+    )
+
+    # pylint: disable=protected-access
+    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    self.assertIn("odml.update_external_kv_cache", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
@@ -163,7 +185,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
 from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import model_builder
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
@@ -90,7 +91,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = gemma1.Gemma(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
@@ -119,7 +120,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = phi2.Phi2(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
@@ -139,7 +140,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = smollm.SmolLM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -148,7 +149,7 @@ class TestModelConversion(googletest.TestCase):
       reason="tests with custom ops are not supported on oss",
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = openelm.OpenELM(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
@@ -157,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = qwen.Qwen(config).eval()
+    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(