ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ai_edge_torch/debug/test/test_culprit.py +8 -3
  2. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  3. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  6. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  7. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  10. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  12. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  13. ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
  14. ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
  15. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  16. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  17. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  18. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  19. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  20. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  21. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  22. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  23. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  24. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  25. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  26. ai_edge_torch/generative/layers/attention.py +2 -6
  27. ai_edge_torch/generative/layers/kv_cache.py +25 -18
  28. ai_edge_torch/generative/layers/normalization.py +1 -3
  29. ai_edge_torch/generative/test/test_kv_cache.py +3 -3
  30. ai_edge_torch/generative/test/test_model_conversion.py +4 -5
  31. ai_edge_torch/generative/test/test_model_conversion_large.py +37 -32
  32. ai_edge_torch/generative/test/utils.py +31 -6
  33. ai_edge_torch/generative/utilities/converter.py +25 -4
  34. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  35. ai_edge_torch/generative/utilities/verifier.py +16 -2
  36. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  37. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
  38. ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
  39. ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
  40. ai_edge_torch/version.py +1 -1
  41. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/METADATA +2 -2
  42. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/RECORD +45 -44
  43. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/LICENSE +0 -0
  44. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/WHEEL +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class TinyLlama(model_builder.DecoderOnlyModel):
+  """A TinyLlama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a TinyLlama model.
 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=TinyLlama,
   )
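
Note: the hunk above shows the pattern applied across the example models in this release: each example now defines a thin named subclass of model_builder.DecoderOnlyModel and passes it to build_decoder_only_model via the new model_class argument. A minimal usage sketch follows; the checkpoint path is a placeholder, not a real file.

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama

# Hypothetical checkpoint path; build_model forwards **kwargs to get_model_config.
model = tiny_llama.build_model("/tmp/tiny_llama", kv_cache_max_len=1024)
# The returned object is now the named subclass rather than a bare DecoderOnlyModel.
assert isinstance(model, tiny_llama.TinyLlama)
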
ai_edge_torch/generative/layers/attention.py
@@ -241,9 +241,7 @@ class CausalSelfAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ class CrossAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
ai_edge_torch/generative/layers/kv_cache.py
@@ -146,7 +146,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
 
@@ -155,17 +155,14 @@ def update(
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-      High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-  # Don't enable HLFB for kv cache op for now, since it won't work with LLM
-  # inference engine. Remove this part once we ship a new LLM inference engine.
-  enable_hlfb=False
-  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
-  return update_func(cache, input_pos, k_slice, v_slice)
+  # Turn dynamic_update_slice updates off for now.
+  use_dus=False
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +178,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def _update_kv_hlfb_impl(
+def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+  """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
 ) -> KVCacheEntry:
-  """Update the cache buffer with High Level Function Boundary annotation."""
-  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
-  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
-      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
-  )
-  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
-  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
-  k, v = builder.mark_outputs(k, v)
-  return KVCacheEntry(k, v)
+  """Update the cache buffer for K and V caches."""
+  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
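
Note: for readers unfamiliar with dynamic-update-slice semantics, the sketch below emulates what _update_kv_impl does using plain torch ops: an out-of-place write of a contiguous slice starting at input_pos[0] along the sequence axis of a [batch, seq, kv_heads, head_dim] cache. The helper name and shapes here are illustrative assumptions; the real dynamic_update_slice op comes from the converter lowering, not from this snippet.

import torch

def emulate_dynamic_update_slice(operand, update, start_indices):
  """Illustrative stand-in: copies `operand` and writes `update` at `start_indices`."""
  result = operand.clone()
  index = tuple(
      slice(int(start), int(start) + size)
      for start, size in zip(start_indices, update.shape)
  )
  result[index] = update
  return result

# Cache shaped [batch=1, seq=8, kv_heads=1, head_dim=4]; write two positions.
k_cache = torch.zeros(1, 8, 1, 4)
k_slice = torch.full((1, 2, 1, 4), 7.0)
input_pos = torch.tensor([0, 1])  # assumed contiguous, per the NB comment above

zero = torch.zeros([]).int()
start_indices = [zero, input_pos.int()[0].reshape([]), zero, zero]
updated = emulate_dynamic_update_slice(k_cache, k_slice, start_indices)
assert updated[0, :2].eq(7).all() and updated[0, 2:].eq(0).all()
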
ai_edge_torch/generative/layers/normalization.py
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
-  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
           "epsilon": eps,
-          "reduction_axes": 3,
+          "reduction_axes": [3],
           "channel_axis": 3,
       },
   )
ai_edge_torch/generative/test/test_kv_cache.py
@@ -71,18 +71,18 @@ class TestKVLayers(googletest.TestCase):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0, 3])
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):
ai_edge_torch/generative/test/test_model_conversion.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -101,8 +100,8 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_toy_model_has_ekv_op(self):
-    """Tests that the model has the external kv cache op."""
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -112,7 +111,7 @@ class TestModelConversion(googletest.TestCase):
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("odml.update_external_kv_cache", op_names)
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
@@ -185,7 +184,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -32,7 +32,6 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_dec
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -53,12 +52,15 @@ class TestModelConversion(googletest.TestCase):
             experimental_default_delegate_latest_features=True,
         )
     )
+    # Default cache_size_limit, 8 is hit and aborts often when the tests are
+    # running all together. Doubles it to avoid abortion.
+    torch._dynamo.config.cache_size_limit = 16
+    np.random.seed(1234)  # Make np.random deterministic.
 
   def _test_model(self, config, model, signature_name, atol, rtol):
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10, dtype=torch.int)
+    seq_len = 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
@@ -74,6 +76,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -93,10 +96,8 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
-    )
+    pytorch_model = gemma1.Gemma1(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -122,10 +123,9 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
-    )
+    pytorch_model = phi2.Phi2(config).eval()
+    # Phi-2 logits are very big, so we need a larger absolute tolerance.
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -142,7 +142,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -151,7 +151,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -160,7 +160,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
@@ -169,26 +169,26 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+    pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_paligemma(self):
+  def disabled_test_paligemma(self):
     config = paligemma.get_fake_model_config()
     pytorch_model = paligemma.PaliGemma(config).eval()
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
     ) ** 2
+
     # Make sure the token size is longer than the number of image patches.
-    tokens_len = num_patches + 10
-    tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    seq_len = num_patches + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
     pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
 
@@ -206,6 +206,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -244,7 +245,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="encode",
     )
     self.assertTrue(
-        np.allclose(
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -258,14 +259,16 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+        np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
     )
     context = torch.from_numpy(
-        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+        np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
     )
     time_embedding = torch.from_numpy(
-        np.random.normal(size=(2, 2)).astype(np.float32)
+        np.random.normal(size=(2, 2), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_diffusion.Diffusion(config).eval()
@@ -284,7 +287,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="diffusion",
     )
     self.assertTrue(
-        np.allclose(
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -298,8 +301,10 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+        np.random.normal(size=(1, 4, 64, 64), scale=0.1).astype(np.float32)
    )
 
     pytorch_model = sd_decoder.Decoder(config).eval()
@@ -316,10 +321,10 @@ class TestModelConversion(googletest.TestCase):
         signature_name="decode",
     )
     self.assertTrue(
-        np.allclose(
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
-            atol=1e-4,
+            atol=1e-3,
            rtol=1e-5,
        )
    )
ai_edge_torch/generative/test/utils.py
@@ -15,6 +15,8 @@
 
 """Common utils for testing."""
 
+import logging
+
 from ai_edge_torch import model
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.lowertools import common_utils
@@ -33,7 +35,7 @@ def compare_tflite_torch(
     atol: float = 1e-5,
     rtol: float = 1e-5,
     **kwargs,
-):
+) -> bool:
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
@@ -49,9 +51,32 @@ def compare_tflite_torch(
       **kwargs,
   )
 
-  return np.allclose(
-      edge_output["logits"],
-      torch_output["logits"].detach().numpy(),
-      atol=atol,
-      rtol=rtol,
+  return compare_logits(
+      edge_output["logits"], torch_output["logits"].detach().numpy(), atol, rtol
   )
+
+
+def compare_logits(
+    edge_logits: np.ndarray,
+    torch_logits: dict[str, torch.Tensor],
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+) -> bool:
+  """Compares logits from edge model and torch model."""
+  if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+    return True
+
+  logging.info("edge_logits: %s", edge_logits)
+  logging.info("torch_logits: %s", torch_logits)
+
+  orig_atol = atol
+  while rtol < 1:
+    atol = orig_atol
+    while atol < 1:
+      if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+        logging.info("Got allclose true with atol=%s, rtol=%s", atol, rtol)
+        return False
+      atol *= 10
+    rtol *= 10
+  logging.info("allclose failed with reasonable atol and rtol.")
+  return False
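
Note: a small usage sketch of the new helper; the arrays below are made-up values. compare_logits returns True only when the straight np.allclose check passes; on any failure it logs the logits, probes larger tolerances for diagnostic logging, and returns False.

import numpy as np
from ai_edge_torch.generative.test import utils as test_utils

edge_logits = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
torch_logits = np.array([[1.0, 2.0, 3.000001]], dtype=np.float32)
assert test_utils.compare_logits(edge_logits, torch_logits, atol=1e-5, rtol=1e-5)

# A large mismatch returns False (and logs which looser tolerances would have passed).
assert not test_utils.compare_logits(edge_logits, torch_logits * 10.0)
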
ai_edge_torch/generative/utilities/converter.py
@@ -15,13 +15,28 @@
 
 """Common utility functions for model conversion."""
 
-from typing import Union
+from functools import partial
+from typing import Any, Union
 
 from ai_edge_torch._convert import converter as converter_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
+import torch.nn as nn
+
+
+class ExportableModule(torch.nn.Module):
+
+  def __init__(self, module, **extra_kwargs):
+    super().__init__()
+    self.module = module
+    self.extra_kwargs = extra_kwargs
+
+  def forward(self, *export_args, **export_kwargs):
+    full_kwargs = {**export_kwargs, **self.extra_kwargs}
+    return self.module(*export_args, **full_kwargs)
 
 
 def convert_to_tflite(
@@ -31,6 +46,7 @@ def convert_to_tflite(
     pixel_values_size: torch.Size = None,
     quantize: bool = True,
     config: cfg.ModelConfig = None,
+    export_config: ExportConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.
 
@@ -97,6 +113,11 @@ def convert_to_tflite(
   )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+
+  # For export, we create a module that captures any non-exportable,
+  # arugments, e.g. the generation config object.
+  mod = ExportableModule(pytorch_model, export_config=export_config)
+
   converter = converter_utils.Converter()
   for i in range(len(prefill_seq_lens)):
     prefill_seq_len = prefill_seq_lens[i]
@@ -108,7 +129,7 @@ def convert_to_tflite(
     prefill_signature_name = f'prefill_{prefill_seq_len}'
     converter.add_signature(
         prefill_signature_name,
-        pytorch_model,
+        mod,
        sample_kwargs={
            'tokens': prefill_tokens,
            'input_pos': prefill_input_pos,
@@ -118,7 +139,7 @@ def convert_to_tflite(
    if prefill_pixel_values is not None:
      converter.add_signature(
          prefill_signature_name + '_pixel',
-          pytorch_model,
+          mod,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
@@ -129,7 +150,7 @@ def convert_to_tflite(
 
  converter.add_signature(
      'decode',
-      pytorch_model,
+      mod,
      sample_kwargs={
          'tokens': decode_token,
          'input_pos': decode_input_pos,
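
Note: ExportableModule exists so that non-tensor arguments, here the ExportConfig, are bound once at wrap time and re-injected on every call, which keeps each converter signature's sample_kwargs tensor-only. A minimal sketch of the pattern; ToyModel is a stand-in, not part of ai_edge_torch.

import torch
from ai_edge_torch.generative.utilities import converter

class ToyModel(torch.nn.Module):

  def forward(self, tokens, export_config=None):
    # A real decoder would consult export_config (e.g. to skip prefill logits).
    return tokens.float() * (1.0 if export_config is not None else 2.0)

mod = converter.ExportableModule(ToyModel(), export_config=object())
out = mod(torch.tensor([1, 2, 3]))  # export_config is injected automatically
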
ai_edge_torch/generative/utilities/model_builder.py
@@ -16,7 +16,8 @@
 """Utilities to be used for re-authoring transformer models."""
 
 import copy
-from typing import Tuple
+from dataclasses import dataclass
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
 TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
 
 
+@dataclass
+class ExportConfig:
+  """Model generating configuration settings."""
+
+  # On prefill signatures, should the model produce logit output?
+  # When False, only decode signatures will produce output.
+  output_logits_on_prefill: bool = False
+
+
 class DecoderOnlyModel(nn.Module):
   """A simple decoder-only transformer model built from the Edge Generative API.
 
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
       mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self.forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
   def forward_with_embeds(
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
       mask: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Forwards the model with input embeddings."""
    assert len(self.transformer_blocks) == len(kv_cache.caches), (
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
      updated_kv_entires.append(kv_entry)
    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {"kv_cache": updated_kv_cache}
+
    x = self.final_norm(x)
    logits = self.lm_head(x)  # (b, t, vocab_size)
    return {"logits": logits, "kv_cache": updated_kv_cache}
@@ -146,8 +165,9 @@ def build_decoder_only_model(
    checkpoint_path: str,
    config: cfg.ModelConfig,
    tensor_names: loading_utils.ModelLoader.TensorNames,
-) -> DecoderOnlyModel:
-  transformer = DecoderOnlyModel(config)
+    model_class: type[nn.Module] = DecoderOnlyModel,
+) -> nn.Module:
+  transformer = model_class(config)
  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
  loader.load(
      transformer, strict=not config.lm_head_share_weight_with_embedding
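
Note: a sketch of what the new ExportConfig changes about the forward signature, using the TinyLlama fake config purely as a small stand-in model (shapes are arbitrary). During prefill (more than one position in input_pos) with output_logits_on_prefill=False, the model returns only the updated KV cache; a single-position decode step always returns logits.

import torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

config = tiny_llama.get_fake_model_config()
model = tiny_llama.TinyLlama(config).eval()
kv = kv_utils.KVCache.from_model_config(config)
export_config = ExportConfig(output_logits_on_prefill=False)

# Prefill: several positions, so only the updated KV cache is returned.
tokens = torch.zeros((1, 8), dtype=torch.int)
input_pos = torch.arange(0, 8, dtype=torch.int)
out = model(tokens, input_pos, kv, export_config=export_config)
assert "logits" not in out

# Decode: a single position still produces logits.
out = model(
    torch.zeros((1, 1), dtype=torch.int),
    torch.tensor([8], dtype=torch.int),
    out["kv_cache"],
    export_config=export_config,
)
assert "logits" in out
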
ai_edge_torch/generative/utilities/verifier.py
@@ -19,6 +19,7 @@ import logging
 from typing import List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
 
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
     """
     super().__init__()
     self.model = model
+    self.export_config = ExportConfig(output_logits_on_prefill=True)
 
   def forward(
       self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
     Returns:
       The output logits and the updated KV cache.
     """
+    # Verification requires logit outputs on prefill for comparison.
+    if (
+        self.export_config is not None
+        and not self.export_config.output_logits_on_prefill
+    ):
+      raise ValueError("Verifier requires logit output on prefill.")
    # Since the reauthored model doesn't include keyword arguments, pass
    # pixel_values only when it is not None. Otherwise, it may raise an error.
    if pixel_values is None:
-      output = self.model.forward(tokens, input_pos, kv_cache)
+      output = self.model.forward(
+          tokens, input_pos, kv_cache, export_config=self.export_config
+      )
    else:
      output = self.model.forward(
-          tokens, input_pos, kv_cache, pixel_values=pixel_values
+          tokens,
+          input_pos,
+          kv_cache,
+          pixel_values=pixel_values,
+          export_config=self.export_config,
      )
    return output["logits"], output["kv_cache"]
ai_edge_torch/odml_torch/lowerings/__init__.py
@@ -21,6 +21,6 @@ from . import _quantized_decomposed
 from . import context
 from . import registry
 from . import utils
-from .registry import decompositions
+from .decomp import decompositions
 from .registry import lookup
 from .registry import lower