ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241213__py3-none-any.whl

Files changed (45)
  1. ai_edge_torch/debug/test/test_culprit.py +8 -3
  2. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  3. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  6. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  7. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  10. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  12. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  13. ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
  14. ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
  15. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  16. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  17. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  18. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  19. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  20. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  21. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  22. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  23. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  24. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  25. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  26. ai_edge_torch/generative/layers/attention.py +2 -6
  27. ai_edge_torch/generative/layers/kv_cache.py +25 -18
  28. ai_edge_torch/generative/layers/normalization.py +1 -3
  29. ai_edge_torch/generative/test/test_kv_cache.py +3 -3
  30. ai_edge_torch/generative/test/test_model_conversion.py +4 -5
  31. ai_edge_torch/generative/test/test_model_conversion_large.py +37 -32
  32. ai_edge_torch/generative/test/utils.py +31 -6
  33. ai_edge_torch/generative/utilities/converter.py +25 -4
  34. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  35. ai_edge_torch/generative/utilities/verifier.py +16 -2
  36. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  37. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
  38. ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
  39. ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
  40. ai_edge_torch/version.py +1 -1
  41. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/METADATA +2 -2
  42. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/RECORD +45 -44
  43. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/LICENSE +0 -0
  44. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/WHEEL +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -17,10 +17,16 @@

  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.utilities import model_builder
+ from torch import nn

  TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD


+ class TinyLlama(model_builder.DecoderOnlyModel):
+   """A TinyLlama model built from the Edge Generative API layers."""
+   pass
+
+
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
    """Returns the model config for a TinyLlama model.

@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
    return config


- def build_model(
-     checkpoint_path: str, **kwargs
- ) -> model_builder.DecoderOnlyModel:
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=get_model_config(**kwargs),
        tensor_names=TENSOR_NAMES,
+       model_class=TinyLlama,
    )
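
The tiny_llama.py change above illustrates the pattern applied across the example models in this release: each example gains a thin subclass of model_builder.DecoderOnlyModel (here TinyLlama) and hands it to build_decoder_only_model through the new model_class argument. A minimal usage sketch, assuming a local checkpoint (the path below is a placeholder):

import ai_edge_torch  # noqa: F401  (ensure the package is on the path)
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama

# Build the example model; extra kwargs are forwarded to get_model_config().
model = tiny_llama.build_model(
    "/tmp/tiny_llama",  # placeholder checkpoint path
    kv_cache_max_len=1024,
)
print(type(model).__name__)  # "TinyLlama" rather than "DecoderOnlyModel"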
ai_edge_torch/generative/layers/attention.py

@@ -241,9 +241,7 @@ class CausalSelfAttention(nn.Module):
        q, k = _embed_rope(q, k, n_elem, rope)

      if kv_cache is not None:
-       kv_cache = kv_utils.update(
-           kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-       )
+       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
        k, v = kv_cache.k_cache, kv_cache.v_cache

      y = self.sdpa_func(
@@ -379,9 +377,7 @@ class CrossAttention(nn.Module):
        q, k = _embed_rope(q, k, n_elem, rope)

      if kv_cache is not None:
-       kv_cache = kv_utils.update(
-           kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-       )
+       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
        k, v = kv_cache.k_cache, kv_cache.v_cache
      if mask is None:
        mask = torch.zeros(
ai_edge_torch/generative/layers/kv_cache.py

@@ -146,7 +146,7 @@ def update(
      input_pos: torch.Tensor,
      k_slice: torch.Tensor,
      v_slice: torch.Tensor,
-     enable_hlfb: bool = True,
+     use_dus: bool = True,
  ) -> KVCacheEntry:
    """Out of place update of Cache buffer.

@@ -155,17 +155,14 @@ def update(
      input_pos (torch.Tensor): The update slice positions.
      k_slice (torch.Tensor): The K slice to be updated in the new cache.
      v_slice (torch.Tensor): The V slice to be updated in the new cache.
-     enable_hlfb (bool, optional): Whether the op is annotated for export with
-       High Level Function Boundary. Defaults to True.

    Returns:
      KVCacheEntry: The updated KVCache entry based on the passed inputs.
    """
-   # Don't enable HLFB for kv cache op for now, since it won't work with LLM
-   # inference engine. Remove this part once we ship a new LLM inference engine.
-   enable_hlfb=False
-   update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
-   return update_func(cache, input_pos, k_slice, v_slice)
+   # Turn dynamic_update_slice updates off for now.
+   use_dus=False
+   update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+   return update_kv_cache(cache, input_pos, k_slice, v_slice)


  def _update_kv_base_impl(
@@ -181,18 +178,28 @@ def _update_kv_base_impl(
    return updated_cache


- def _update_kv_hlfb_impl(
+ def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+   """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+   zero = torch.zeros([]).int()
+   positions = positions.int()[0].reshape([])
+   return [zero, positions, zero, zero]
+
+
+ def _update_kv_impl(
      cache: KVCacheEntry,
      input_pos: torch.Tensor,
      k_slice: torch.Tensor,
      v_slice: torch.Tensor,
  ) -> KVCacheEntry:
-   """Update the cache buffer with High Level Function Boundary annotation."""
-   builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
-   k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
-       cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
-   )
-   k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
-   v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
-   k, v = builder.mark_outputs(k, v)
-   return KVCacheEntry(k, v)
+   """Update the cache buffer for K and V caches."""
+   # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+   k_slice_indices = _get_slice_indices(input_pos)
+   v_slice_indices = _get_slice_indices(input_pos)
+
+   k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+   v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+   updated_cache = KVCacheEntry(k, v)
+   return updated_cache
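
The rewritten kv_cache.py drops the HLFB-annotated index_copy path in favor of a dynamic_update_slice update, which lowers to TFLite's DYNAMIC_UPDATE_SLICE op (see the test change further down). Because the start indices are derived only from input_pos[0], the update assumes contiguous positions, as the NB comment notes. Below is a rough pure-PyTorch stand-in for the op's semantics, for illustration only; dynamic_update_slice_ref is a hypothetical helper, not the op the module actually imports:

import torch

def dynamic_update_slice_ref(operand, update, start_indices):
  """Out-of-place copy of `update` into `operand` starting at `start_indices`."""
  result = operand.clone()
  idx = tuple(
      slice(int(i), int(i) + s) for i, s in zip(start_indices, update.shape)
  )
  result[idx] = update
  return result

# A (batch, kv_len, num_groups, head_dim) cache updated at sequence position 2
# with a two-token K/V slice, mirroring what _update_kv_impl does above.
cache = torch.zeros(1, 8, 4, 16)
k_slice = torch.ones(1, 2, 4, 16)
indices = [torch.tensor(0), torch.tensor(2), torch.tensor(0), torch.tensor(0)]
updated = dynamic_update_slice_ref(cache, k_slice, indices)
assert updated[0, 2:4].eq(1).all() and updated[0, :2].eq(0).all()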
ai_edge_torch/generative/layers/normalization.py

@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
    """
    x = torch.permute(x, (0, 2, 3, 1))

-   # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-   # int32 when the bug is fixed.
    builder = StableHLOCompositeBuilder(
        name="odml.group_norm",
        attr={
            "num_groups": num_groups,
            "epsilon": eps,
-           "reduction_axes": 3,
+           "reduction_axes": [3],
            "channel_axis": 3,
        },
    )
ai_edge_torch/generative/test/test_kv_cache.py

@@ -71,18 +71,18 @@ class TestKVLayers(googletest.TestCase):
        [0, 0, 5, 5, 0, 0, 0, 0],
    )
    # multi-slice update
-   input_pos = torch.tensor([0, 3])
+   input_pos = torch.tensor([0, 1])
    k_slice = v_slice = torch.full(
        (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
    )
    updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
    self.assertEqual(
        updated_entry.k_cache.numpy().flatten().tolist(),
-       [7, 7, 0, 0, 0, 0, 7, 7],
+       [7, 7, 7, 7, 0, 0, 0, 0],
    )
    self.assertEqual(
        updated_entry.v_cache.numpy().flatten().tolist(),
-       [7, 7, 0, 0, 0, 0, 7, 7],
+       [7, 7, 7, 7, 0, 0, 0, 0],
    )

  def test_serialization(self):
ai_edge_torch/generative/test/test_model_conversion.py

@@ -21,7 +21,6 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.test import utils as test_utils
- from ai_edge_torch.generative.utilities import model_builder
  import numpy as np
  import torch

@@ -101,8 +100,8 @@ class TestModelConversion(googletest.TestCase):
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
  )
- def test_toy_model_has_ekv_op(self):
-   """Tests that the model has the external kv cache op."""
+ def test_toy_model_has_dus_op(self):
+   """Tests that the model has the dynamic update slice op."""
    _, edge_model, _ = self._get_params(enable_hlfb=True)
    interpreter_ = interpreter.InterpreterWithCustomOps(
        custom_op_registerers=["GenAIOpsRegisterer"],
@@ -112,7 +111,7 @@ class TestModelConversion(googletest.TestCase):

    # pylint: disable=protected-access
    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-   self.assertIn("odml.update_external_kv_cache", op_names)
+   self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)

  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
    # prefill
@@ -185,7 +184,7 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_tiny_llama_multisig(self):
    config = tiny_llama.get_fake_model_config()
-   pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+   pytorch_model = tiny_llama.TinyLlama(config).eval()
    self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)

ai_edge_torch/generative/test/test_model_conversion_large.py

@@ -32,7 +32,6 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_dec
  from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
  from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.test import utils as test_utils
- from ai_edge_torch.generative.utilities import model_builder
  import numpy as np
  import torch

@@ -53,12 +52,15 @@ class TestModelConversion(googletest.TestCase):
              experimental_default_delegate_latest_features=True,
          )
      )
+     # Default cache_size_limit, 8 is hit and aborts often when the tests are
+     # running all together. Doubles it to avoid abortion.
+     torch._dynamo.config.cache_size_limit = 16
+     np.random.seed(1234)  # Make np.random deterministic.

  def _test_model(self, config, model, signature_name, atol, rtol):
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, 10, dtype=torch.int)
+   seq_len = 10
+   tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+   input_pos = torch.arange(0, seq_len, dtype=torch.int)
    kv = kv_cache.KVCache.from_model_config(config)

    edge_model = ai_edge_torch.signature(
@@ -74,6 +76,7 @@ class TestModelConversion(googletest.TestCase):
        self._interpreter_builder(edge_model.tflite_model())
    )

+   tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
    self.assertTrue(
        test_utils.compare_tflite_torch(
            edge_model,
@@ -93,10 +96,8 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_gemma1(self):
    config = gemma1.get_fake_model_config()
-   pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-   self._test_model(
-       config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
-   )
+   pytorch_model = gemma1.Gemma1(config).eval()
+   self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
@@ -122,10 +123,9 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_phi2(self):
    config = phi2.get_fake_model_config()
-   pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-   self._test_model(
-       config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
-   )
+   pytorch_model = phi2.Phi2(config).eval()
+   # Phi-2 logits are very big, so we need a larger absolute tolerance.
+   self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
@@ -142,7 +142,7 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_smollm(self):
    config = smollm.get_fake_model_config()
-   pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+   pytorch_model = smollm.SmolLM(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

  @googletest.skipIf(
@@ -151,7 +151,7 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_openelm(self):
    config = openelm.get_fake_model_config()
-   pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+   pytorch_model = openelm.OpenELM(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

  @googletest.skipIf(
@@ -160,7 +160,7 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_qwen(self):
    config = qwen.get_fake_model_config()
-   pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+   pytorch_model = qwen.Qwen(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  @googletest.skipIf(
@@ -169,26 +169,26 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_amd_llama_135m(self):
    config = amd_llama_135m.get_fake_model_config()
-   pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-   self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+   pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+   self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
  )
- def test_paligemma(self):
+ def disabled_test_paligemma(self):
    config = paligemma.get_fake_model_config()
    pytorch_model = paligemma.PaliGemma(config).eval()
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+
    image_embedding_config = config.image_encoder_config.image_embedding
    num_patches = (
        image_embedding_config.image_size // image_embedding_config.patch_size
    ) ** 2
+
    # Make sure the token size is longer than the number of image patches.
-   tokens_len = num_patches + 10
-   tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+   seq_len = num_patches + 10
+   tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+   input_pos = torch.arange(0, seq_len, dtype=torch.int)
    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")

@@ -206,6 +206,7 @@ class TestModelConversion(googletest.TestCase):
        self._interpreter_builder(edge_model.tflite_model())
    )

+   tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
    self.assertTrue(
        test_utils.compare_tflite_torch(
            edge_model,
@@ -244,7 +245,7 @@ class TestModelConversion(googletest.TestCase):
        signature_name="encode",
    )
    self.assertTrue(
-       np.allclose(
+       test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -258,14 +259,16 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_stable_diffusion_diffusion(self):
    config = sd_diffusion.get_fake_model_config(2)
+   # Reduce stddev(scale) of input values to avoid too big output logits which
+   # fails comparisons with reasonable tolerances.
    latents = torch.from_numpy(
-       np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+       np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
    )
    context = torch.from_numpy(
-       np.random.normal(size=(2, 4, 4)).astype(np.float32)
+       np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
    )
    time_embedding = torch.from_numpy(
-       np.random.normal(size=(2, 2)).astype(np.float32)
+       np.random.normal(size=(2, 2), scale=0.1).astype(np.float32)
    )

    pytorch_model = sd_diffusion.Diffusion(config).eval()
@@ -284,7 +287,7 @@ class TestModelConversion(googletest.TestCase):
        signature_name="diffusion",
    )
    self.assertTrue(
-       np.allclose(
+       test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -298,8 +301,10 @@ class TestModelConversion(googletest.TestCase):
  )
  def test_stable_diffusion_decoder(self):
    config = sd_decoder.get_fake_model_config()
+   # Reduce stddev(scale) of input values to avoid too big output logits which
+   # fails comparisons with reasonable tolerances.
    latents = torch.from_numpy(
-       np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+       np.random.normal(size=(1, 4, 64, 64), scale=0.1).astype(np.float32)
    )

    pytorch_model = sd_decoder.Decoder(config).eval()
@@ -316,10 +321,10 @@ class TestModelConversion(googletest.TestCase):
        signature_name="decode",
    )
    self.assertTrue(
-       np.allclose(
+       test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
-           atol=1e-4,
+           atol=1e-3,
            rtol=1e-5,
        )
    )
ai_edge_torch/generative/test/utils.py

@@ -15,6 +15,8 @@

  """Common utils for testing."""

+ import logging
+
  from ai_edge_torch import model
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.lowertools import common_utils
@@ -33,7 +35,7 @@ def compare_tflite_torch(
      atol: float = 1e-5,
      rtol: float = 1e-5,
      **kwargs,
- ):
+ ) -> bool:
    """Compares torch models and TFLite models."""
    values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
    flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
@@ -49,9 +51,32 @@ def compare_tflite_torch(
        **kwargs,
    )

-   return np.allclose(
-       edge_output["logits"],
-       torch_output["logits"].detach().numpy(),
-       atol=atol,
-       rtol=rtol,
+   return compare_logits(
+       edge_output["logits"], torch_output["logits"].detach().numpy(), atol, rtol
    )
+
+
+ def compare_logits(
+     edge_logits: np.ndarray,
+     torch_logits: dict[str, torch.Tensor],
+     atol: float = 1e-5,
+     rtol: float = 1e-5,
+ ) -> bool:
+   """Compares logits from edge model and torch model."""
+   if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+     return True
+
+   logging.info("edge_logits: %s", edge_logits)
+   logging.info("torch_logits: %s", torch_logits)
+
+   orig_atol = atol
+   while rtol < 1:
+     atol = orig_atol
+     while atol < 1:
+       if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+         logging.info("Got allclose true with atol=%s, rtol=%s", atol, rtol)
+         return False
+       atol *= 10
+     rtol *= 10
+   logging.info("allclose failed with reasonable atol and rtol.")
+   return False
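
The new compare_logits helper keeps the pass/fail decision at the caller's tolerances but, on failure, searches for the smallest atol/rtol pair that would have passed and logs it, which makes tolerance regressions easier to triage. An illustrative call with dummy arrays, not taken from the test suite:

import numpy as np
from ai_edge_torch.generative.test import utils as test_utils

edge = np.array([1.0, 2.0, 3.0])
torch_ref = edge + 1e-2
# Fails at the default tolerances of 1e-5, but logs the tolerances that would pass.
assert not test_utils.compare_logits(edge, torch_ref)
# Passes once atol covers the 1e-2 difference.
assert test_utils.compare_logits(edge, torch_ref, atol=1e-1)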
ai_edge_torch/generative/utilities/converter.py

@@ -15,13 +15,28 @@

  """Common utility functions for model conversion."""

- from typing import Union
+ from functools import partial
+ from typing import Any, Union

  from ai_edge_torch._convert import converter as converter_utils
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.quantize import quant_recipes
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
  import torch
+ import torch.nn as nn
+
+
+ class ExportableModule(torch.nn.Module):
+
+   def __init__(self, module, **extra_kwargs):
+     super().__init__()
+     self.module = module
+     self.extra_kwargs = extra_kwargs
+
+   def forward(self, *export_args, **export_kwargs):
+     full_kwargs = {**export_kwargs, **self.extra_kwargs}
+     return self.module(*export_args, **full_kwargs)


  def convert_to_tflite(
@@ -31,6 +46,7 @@ def convert_to_tflite(
      pixel_values_size: torch.Size = None,
      quantize: bool = True,
      config: cfg.ModelConfig = None,
+     export_config: ExportConfig = None,
  ):
    """Converts a nn.Module model to multi-signature tflite model.

@@ -97,6 +113,11 @@ def convert_to_tflite(
    )

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+
+   # For export, we create a module that captures any non-exportable,
+   # arugments, e.g. the generation config object.
+   mod = ExportableModule(pytorch_model, export_config=export_config)
+
    converter = converter_utils.Converter()
    for i in range(len(prefill_seq_lens)):
      prefill_seq_len = prefill_seq_lens[i]
@@ -108,7 +129,7 @@ def convert_to_tflite(
      prefill_signature_name = f'prefill_{prefill_seq_len}'
      converter.add_signature(
          prefill_signature_name,
-         pytorch_model,
+         mod,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
@@ -118,7 +139,7 @@ def convert_to_tflite(
      if prefill_pixel_values is not None:
        converter.add_signature(
            prefill_signature_name + '_pixel',
-           pytorch_model,
+           mod,
            sample_kwargs={
                'tokens': prefill_tokens,
                'input_pos': prefill_input_pos,
@@ -129,7 +150,7 @@ def convert_to_tflite(

    converter.add_signature(
        'decode',
-       pytorch_model,
+       mod,
        sample_kwargs={
            'tokens': decode_token,
            'input_pos': decode_input_pos,
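
In converter.py, ExportableModule pins extra keyword arguments (here the ExportConfig) onto every forward call, so the exported signatures keep the plain tokens/input_pos/kv_cache interface while the non-tensor config rides along. A small sketch of the mechanism using a made-up module (the Toy class below is purely illustrative):

import torch
from ai_edge_torch.generative.utilities.converter import ExportableModule

class Toy(torch.nn.Module):
  """Stand-in for a model whose forward takes an extra non-tensor kwarg."""

  def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return x * scale

# Freeze `scale` the same way convert_to_tflite freezes `export_config`.
wrapped = ExportableModule(Toy(), scale=2.0)
print(wrapped(torch.tensor([1.0, 2.0])))  # tensor([2., 4.])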
ai_edge_torch/generative/utilities/model_builder.py

@@ -16,7 +16,8 @@
  """Utilities to be used for re-authoring transformer models."""

  import copy
- from typing import Tuple
+ from dataclasses import dataclass
+ from typing import Optional, Tuple

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
  TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"


+ @dataclass
+ class ExportConfig:
+   """Model generating configuration settings."""
+
+   # On prefill signatures, should the model produce logit output?
+   # When False, only decode signatures will produce output.
+   output_logits_on_prefill: bool = False
+
+
  class DecoderOnlyModel(nn.Module):
    """A simple decoder-only transformer model built from the Edge Generative API.

@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
+     export_config: Optional[ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    _, seq_len = tokens.size()
    assert self.config.max_seq_len >= seq_len, (
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
      mask = mask[:, :, :, : self.config.kv_cache_max]

    return self.forward_with_embeds(
-       input_embeds, rope, mask, input_pos, kv_cache
+       input_embeds, rope, mask, input_pos, kv_cache, export_config
    )

  def forward_with_embeds(
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
      mask: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
+     export_config: Optional[ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Forwards the model with input embeddings."""
    assert len(self.transformer_blocks) == len(kv_cache.caches), (
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
      updated_kv_entires.append(kv_entry)
    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

+   if export_config is not None:
+     if (
+         torch.numel(input_pos) > 1
+         and not export_config.output_logits_on_prefill
+     ):
+       return {"kv_cache": updated_kv_cache}
+
    x = self.final_norm(x)
    logits = self.lm_head(x)  # (b, t, vocab_size)
    return {"logits": logits, "kv_cache": updated_kv_cache}
@@ -146,8 +165,9 @@ def build_decoder_only_model(
    checkpoint_path: str,
    config: cfg.ModelConfig,
    tensor_names: loading_utils.ModelLoader.TensorNames,
- ) -> DecoderOnlyModel:
-   transformer = DecoderOnlyModel(config)
+   model_class: type[nn.Module] = DecoderOnlyModel,
+ ) -> nn.Module:
+   transformer = model_class(config)
    loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
    loader.load(
        transformer, strict=not config.lm_head_share_weight_with_embedding
19
19
  from typing import List
20
20
 
21
21
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
22
23
  import torch
23
24
 
24
25
 
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
40
41
  """
41
42
  super().__init__()
42
43
  self.model = model
44
+ self.export_config = ExportConfig(output_logits_on_prefill=True)
43
45
 
44
46
  def forward(
45
47
  self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
103
105
  Returns:
104
106
  The output logits and the updated KV cache.
105
107
  """
108
+ # Verification requires logit outputs on prefill for comparison.
109
+ if (
110
+ self.export_config is not None
111
+ and not self.export_config.output_logits_on_prefill
112
+ ):
113
+ raise ValueError("Verifier requires logit output on prefill.")
106
114
  # Since the reauthored model doesn't include keyword arguments, pass
107
115
  # pixel_values only when it is not None. Otherwise, it may raise an error.
108
116
  if pixel_values is None:
109
- output = self.model.forward(tokens, input_pos, kv_cache)
117
+ output = self.model.forward(
118
+ tokens, input_pos, kv_cache, export_config=self.export_config
119
+ )
110
120
  else:
111
121
  output = self.model.forward(
112
- tokens, input_pos, kv_cache, pixel_values=pixel_values
122
+ tokens,
123
+ input_pos,
124
+ kv_cache,
125
+ pixel_values=pixel_values,
126
+ export_config=self.export_config,
113
127
  )
114
128
  return output["logits"], output["kv_cache"]
115
129
 
ai_edge_torch/odml_torch/lowerings/__init__.py

@@ -21,6 +21,6 @@ from . import _quantized_decomposed
  from . import context
  from . import registry
  from . import utils
- from .registry import decompositions
+ from .decomp import decompositions
  from .registry import lookup
  from .registry import lower