ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (56)
  1. ai_edge_torch/__init__.py +1 -1
  2. ai_edge_torch/_config.py +52 -0
  3. ai_edge_torch/_convert/test/test_convert.py +1 -2
  4. ai_edge_torch/debug/test/test_culprit.py +8 -3
  5. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  6. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  7. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  8. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  9. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  10. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  11. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  12. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  15. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  16. ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
  17. ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
  18. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  19. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  20. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  21. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  22. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  23. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  24. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  25. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  26. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  27. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  29. ai_edge_torch/generative/layers/attention.py +2 -6
  30. ai_edge_torch/generative/layers/kv_cache.py +24 -18
  31. ai_edge_torch/generative/layers/normalization.py +1 -3
  32. ai_edge_torch/generative/test/test_kv_cache.py +3 -3
  33. ai_edge_torch/generative/test/test_model_conversion.py +12 -14
  34. ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
  35. ai_edge_torch/generative/test/utils.py +31 -6
  36. ai_edge_torch/generative/utilities/converter.py +25 -4
  37. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  38. ai_edge_torch/generative/utilities/verifier.py +16 -2
  39. ai_edge_torch/lowertools/_shim.py +4 -2
  40. ai_edge_torch/lowertools/test_utils.py +4 -2
  41. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  42. ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
  43. ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
  44. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
  45. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
  46. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
  47. ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
  48. ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
  49. ai_edge_torch/version.py +1 -1
  50. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
  51. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
  52. ai_edge_torch/config.py +0 -27
  53. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
  54. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
  55. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
  56. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/smollm/smollm.py

@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
+class SmolLM(model_builder.DecoderOnlyModel):
+  """A SmolLM model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a SmolLM 135M model.
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=SmolLM,
   )
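
For reference, a minimal usage sketch of the reworked builder; it assumes that build_decoder_only_model instantiates the class passed via model_class, and the checkpoint path is a placeholder:

from ai_edge_torch.generative.examples.smollm import smollm

# Hypothetical checkpoint location; kv_cache_max_len is forwarded to get_model_config().
model = smollm.build_model("/tmp/smollm_ckpt", kv_cache_max_len=1024)
# Assumption: passing model_class=SmolLM means the returned nn.Module is a SmolLM instance.
assert isinstance(model, smollm.SmolLM)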

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -15,13 +15,14 @@
 
 """A toy example which has basic transformer block (w/ externalized KV-Cache)."""
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 from torch import nn
 
@@ -62,6 +63,7 @@ class ToyModelWithKVCache(torch.nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
   ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
     x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
@@ -77,8 +79,16 @@ class ToyModelWithKVCache(torch.nn.Module):
       if kv_entry:
        updated_kv_entires.append(kv_entry)
 
-    x = self.final_norm(x)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {'kv_cache': updated_kv_cache}
+
+    x = self.final_norm(x)
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
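
The forward pass above now threads an optional ExportConfig through the toy model: on a prefill call (more than one position in input_pos) with output_logits_on_prefill disabled, the final norm and LM head are skipped and only the updated KV cache is returned. A minimal sketch of that check in isolation; _ExportConfigSketch is a hypothetical stand-in for model_builder.ExportConfig, whose full field list is not shown in this diff:

import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class _ExportConfigSketch:  # hypothetical stand-in for model_builder.ExportConfig
  output_logits_on_prefill: bool = False


def wants_logits(
    input_pos: torch.Tensor, export_config: Optional[_ExportConfigSketch]
) -> bool:
  """Mirrors the new prefill check: logits are skipped when >1 position is processed."""
  if export_config is None:
    return True
  return torch.numel(input_pos) <= 1 or export_config.output_logits_on_prefill


prefill_pos = torch.arange(0, 16, dtype=torch.int)  # prefill: many positions
decode_pos = torch.tensor([16], dtype=torch.int)    # decode: a single position
export_cfg = _ExportConfigSketch(output_logits_on_prefill=False)
assert not wants_logits(prefill_pos, export_cfg)  # prefill returns only the updated KV cache
assert wants_logits(decode_pos, export_cfg)       # decode still produces logits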

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -63,6 +64,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class TinyLlama(model_builder.DecoderOnlyModel):
+  """A TinyLlama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a TinyLlama model.
 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=TinyLlama,
   )

ai_edge_torch/generative/layers/attention.py

@@ -241,9 +241,7 @@ class CausalSelfAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ class CrossAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(

ai_edge_torch/generative/layers/kv_cache.py

@@ -20,6 +20,7 @@ from typing import List, Tuple
 
 from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
+from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
 import torch
 import torch.utils._pytree as pytree
 
@@ -146,7 +147,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
 
@@ -155,17 +156,12 @@
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-      High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-  # Don't enable HLFB for kv cache op for now, since it won't work with LLM
-  # inference engine. Remove this part once we ship a new LLM inference engine.
-  enable_hlfb=False
-  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
-  return update_func(cache, input_pos, k_slice, v_slice)
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +177,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def _update_kv_hlfb_impl(
+def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+  """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
 ) -> KVCacheEntry:
-  """Update the cache buffer with High Level Function Boundary annotation."""
-  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
-  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
-      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
-  )
-  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
-  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
-  k, v = builder.mark_outputs(k, v)
-  return KVCacheEntry(k, v)
+  """Update the cache buffer for K and V caches."""
+  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
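
The new _update_kv_impl path relies on dynamic_update_slice, which writes the K/V slice into the cache starting at input_pos[0] along the sequence axis; hence the comment that input_pos is assumed to be a contiguous range. A rough sketch of the intended semantics in plain PyTorch (reference_dus_update is a hypothetical helper for illustration, assuming a [batch, seq_len, num_query_groups, head_dim] cache layout):

import torch


def reference_dus_update(
    cache: torch.Tensor, update: torch.Tensor, start: int
) -> torch.Tensor:
  """Out-of-place slice write along the sequence axis (dim 1)."""
  out = cache.clone()
  out[:, start:start + update.shape[1]] = update
  return out


# Illustrative shapes: a 4-position cache and a 2-position update written at position 0,
# mirroring the flattened expectation [7, 7, 7, 7, 0, 0, 0, 0] in the updated KV test.
cache = torch.zeros(1, 4, 1, 2)
update = torch.full((1, 2, 1, 2), 7.0)
print(reference_dus_update(cache, update, start=0).flatten().tolist())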

ai_edge_torch/generative/layers/normalization.py

@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
-  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
           "epsilon": eps,
-          "reduction_axes": 3,
+          "reduction_axes": [3],
           "channel_axis": 3,
       },
   )

ai_edge_torch/generative/test/test_kv_cache.py

@@ -71,18 +71,18 @@ class TestKVLayers(googletest.TestCase):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0, 3])
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7, 0, 0, 0, 0, 7, 7],
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):

ai_edge_torch/generative/test/test_model_conversion.py

@@ -16,12 +16,10 @@
 """Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -84,25 +82,25 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
-  def test_toy_model_has_ekv_op(self):
-    """Tests that the model has the external kv cache op."""
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -112,7 +110,7 @@ class TestModelConversion(googletest.TestCase):
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("odml.update_external_kv_cache", op_names)
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
@@ -180,12 +178,12 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 

ai_edge_torch/generative/test/test_model_conversion_large.py

@@ -16,7 +16,6 @@
 """Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
@@ -32,7 +31,6 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_dec
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -53,12 +51,15 @@ class TestModelConversion(googletest.TestCase):
             experimental_default_delegate_latest_features=True,
         )
     )
+    # Default cache_size_limit, 8 is hit and aborts often when the tests are
+    # running all together. Doubles it to avoid abortion.
+    torch._dynamo.config.cache_size_limit = 16
+    np.random.seed(1234)  # Make np.random deterministic.
 
   def _test_model(self, config, model, signature_name, atol, rtol):
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10, dtype=torch.int)
+    seq_len = 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
@@ -74,6 +75,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
    self.assertTrue(
        test_utils.compare_tflite_torch(
            edge_model,
@@ -88,19 +90,17 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
-    )
+    pytorch_model = gemma1.Gemma1(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
@@ -108,8 +108,8 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_llama(self):
     config = llama.get_fake_model_config()
@@ -117,19 +117,18 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
-    )
+    pytorch_model = phi2.Phi2(config).eval()
+    # Phi-2 logits are very big, so we need a larger absolute tolerance.
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
@@ -137,58 +136,58 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+    pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
-  def test_paligemma(self):
+  def disabled_test_paligemma(self):
     config = paligemma.get_fake_model_config()
     pytorch_model = paligemma.PaliGemma(config).eval()
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
     ) ** 2
+
     # Make sure the token size is longer than the number of image patches.
-    tokens_len = num_patches + 10
-    tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    seq_len = num_patches + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
     pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
 
@@ -206,6 +205,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
    self.assertTrue(
        test_utils.compare_tflite_torch(
            edge_model,
@@ -221,8 +221,8 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
@@ -244,7 +244,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="encode",
     )
     self.assertTrue(
-        np.allclose(
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -253,19 +253,21 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+        np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
     )
     context = torch.from_numpy(
-        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+        np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
     )
     time_embedding = torch.from_numpy(
-        np.random.normal(size=(2, 2)).astype(np.float32)
+        np.random.normal(size=(2, 2), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_diffusion.Diffusion(config).eval()
@@ -284,7 +286,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="diffusion",
     )
     self.assertTrue(
-        np.allclose(
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -293,13 +295,15 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
  )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+        np.random.normal(size=(1, 4, 64, 64), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_decoder.Decoder(config).eval()
@@ -316,10 +320,10 @@ class TestModelConversion(googletest.TestCase):
         signature_name="decode",
     )
     self.assertTrue(
-        np.allclose(
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
-            atol=1e-4,
+            atol=1e-3,
            rtol=1e-5,
        )
    )

ai_edge_torch/generative/test/utils.py

@@ -15,6 +15,8 @@
 
 """Common utils for testing."""
 
+import logging
+
 from ai_edge_torch import model
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.lowertools import common_utils
@@ -33,7 +35,7 @@ def compare_tflite_torch(
     atol: float = 1e-5,
     rtol: float = 1e-5,
     **kwargs,
-):
+) -> bool:
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
@@ -49,9 +51,32 @@ def compare_tflite_torch(
       **kwargs,
   )
 
-  return np.allclose(
-      edge_output["logits"],
-      torch_output["logits"].detach().numpy(),
-      atol=atol,
-      rtol=rtol,
+  return compare_logits(
+      edge_output["logits"], torch_output["logits"].detach().numpy(), atol, rtol
   )
+
+
+def compare_logits(
+    edge_logits: np.ndarray,
+    torch_logits: dict[str, torch.Tensor],
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+) -> bool:
+  """Compares logits from edge model and torch model."""
+  if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+    return True
+
+  logging.info("edge_logits: %s", edge_logits)
+  logging.info("torch_logits: %s", torch_logits)
+
+  orig_atol = atol
+  while rtol < 1:
+    atol = orig_atol
+    while atol < 1:
+      if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+        logging.info("Got allclose true with atol=%s, rtol=%s", atol, rtol)
+        return False
+      atol *= 10
+    rtol *= 10
+  logging.info("allclose failed with reasonable atol and rtol.")
+  return False
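
A brief usage sketch of the new compare_logits helper (values are illustrative):

import numpy as np
from ai_edge_torch.generative.test import utils as test_utils

edge = np.array([[0.10, 0.20, 0.30]], dtype=np.float32)
torch_ref = np.array([[0.10, 0.20, 0.31]], dtype=np.float32)

# Fails at tight tolerances; the helper then logs the smallest atol/rtol that would pass.
assert not test_utils.compare_logits(edge, torch_ref, atol=1e-5, rtol=1e-5)
# Passes once the absolute tolerance covers the 0.01 difference.
assert test_utils.compare_logits(edge, torch_ref, atol=2e-2, rtol=1e-5)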