ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +24 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +12 -14
- ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/generative/examples/smollm/smollm.py
+++ b/ai_edge_torch/generative/examples/smollm/smollm.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
+class SmolLM(model_builder.DecoderOnlyModel):
+  """A SmolLM model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a SmolLM 135M model.
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=SmolLM,
   )
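The SmolLM hunks above show the pattern this release applies across the example models (TinyLlama below and, per the file list, gemma1, phi2, openelm, qwen, and amd_llama_135m): a thin subclass of `model_builder.DecoderOnlyModel` plus a `model_class=` argument to `build_decoder_only_model`. A minimal sketch of what the new surface implies for callers; the checkpoint path is a hypothetical placeholder:

```python
# Sketch only, assuming a locally downloaded SmolLM checkpoint at a made-up path.
from ai_edge_torch.generative.examples.smollm import smollm

model = smollm.build_model("/tmp/smollm_135m", kv_cache_max_len=1024)

# build_model now constructs the dedicated SmolLM class, so callers can
# branch on the concrete model type instead of the generic DecoderOnlyModel.
assert isinstance(model, smollm.SmolLM)
```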
--- a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -15,13 +15,14 @@
 
 """A toy example which has basic transformer block (w/ externalized KV-Cache)."""
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 from torch import nn
 
@@ -62,6 +63,7 @@ class ToyModelWithKVCache(torch.nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
   ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
     x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
@@ -77,8 +79,16 @@ class ToyModelWithKVCache(torch.nn.Module):
       if kv_entry:
         updated_kv_entires.append(kv_entry)
 
-    x = self.final_norm(x)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {'kv_cache': updated_kv_cache}
+
+    x = self.final_norm(x)
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
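The new `export_config` branch gates prefill logits: when more than one position is passed (`torch.numel(input_pos) > 1` indicates prefill) and `output_logits_on_prefill` is false, the forward pass returns only the updated KV cache, skipping the final norm and LM head. A self-contained sketch of that gating logic, with a stand-in for the real `ExportConfig`:

```python
import dataclasses
import torch


@dataclasses.dataclass
class ExportConfig:
  """Stand-in for model_builder.ExportConfig; field name from the diff."""
  output_logits_on_prefill: bool = False


def should_compute_logits(input_pos: torch.Tensor, cfg: ExportConfig) -> bool:
  # Prefill passes a run of positions at once; single-token decode passes one.
  return torch.numel(input_pos) <= 1 or cfg.output_logits_on_prefill


print(should_compute_logits(torch.arange(10), ExportConfig()))      # False
print(should_compute_logits(torch.tensor([10]), ExportConfig()))    # True
print(should_compute_logits(torch.arange(10), ExportConfig(True)))  # True
```

Skipping the LM head on prefill avoids materializing a `[batch, seq_len, vocab]` logits tensor that callers of a prefill signature never read.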
--- a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -63,6 +64,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
--- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class TinyLlama(model_builder.DecoderOnlyModel):
+  """A TinyLlama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a TinyLlama model.
 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=TinyLlama,
   )
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -241,9 +241,7 @@ class CausalSelfAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ class CrossAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
--- a/ai_edge_torch/generative/layers/kv_cache.py
+++ b/ai_edge_torch/generative/layers/kv_cache.py
@@ -20,6 +20,7 @@ from typing import List, Tuple
 
 from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
+from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
 import torch
 import torch.utils._pytree as pytree
 
@@ -146,7 +147,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
 
@@ -155,17 +156,12 @@ def update(
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-      High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-
-
-  enable_hlfb=False
-  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
-  return update_func(cache, input_pos, k_slice, v_slice)
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +177,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def _update_kv_hlfb_impl(
+def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+  """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
    v_slice: torch.Tensor,
 ) -> KVCacheEntry:
-  """Update the cache buffer
-
-
-
-  )
-
-
-
-
+  """Update the cache buffer for K and V caches."""
+  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
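`_get_slice_indices` keeps only `input_pos[0]`, reshaped to a 0-rank tensor per cache dimension, so the new `use_dus` path writes one contiguous block into the `[batch, kv_len, heads, dim]` cache; the `# NB:` comment records that assumption. Below is a pure-PyTorch illustration of the intended semantics. `dynamic_update_slice` itself is the ai_edge_torch helper imported in the hunk; the slice-assignment emulation here is an assumption for exposition, not the library implementation:

```python
import torch


def emulated_dynamic_update_slice(
    cache: torch.Tensor, update: torch.Tensor, indices: list[torch.Tensor]
) -> torch.Tensor:
  """Writes `update` into a copy of `cache` starting at `indices`."""
  out = cache.clone()
  b, p, h, d = (int(i) for i in indices)
  s = update.shape
  out[b:b + s[0], p:p + s[1], h:h + s[2], d:d + s[3]] = update
  return out


k_cache = torch.zeros(1, 8, 1, 4)          # [batch, kv_len, heads, head_dim]
k_slice = torch.full((1, 2, 1, 4), 7.0)    # two new positions
input_pos = torch.tensor([2, 3])           # must be contiguous

zero = torch.zeros([]).int()
indices = [zero, input_pos.int()[0].reshape([]), zero, zero]

updated = emulated_dynamic_update_slice(k_cache, k_slice, indices)
assert updated[0, 2:4].eq(7.0).all() and updated[0, 4:].eq(0.0).all()
```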
--- a/ai_edge_torch/generative/layers/normalization.py
+++ b/ai_edge_torch/generative/layers/normalization.py
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
-  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
           "epsilon": eps,
-          "reduction_axes": 3,
+          "reduction_axes": [3],
           "channel_axis": 3,
       },
   )
--- a/ai_edge_torch/generative/test/test_kv_cache.py
+++ b/ai_edge_torch/generative/test/test_kv_cache.py
@@ -71,18 +71,18 @@ class TestKVLayers(googletest.TestCase):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0,
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):
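The updated expectations follow from that contiguous write: with `input_pos = torch.tensor([0, 1])` and a `(1, 2, NUM_QG, HEAD_DIM)` slice of sevens, the first two cache positions are overwritten. A quick check of the arithmetic, assuming hypothetical sizes that yield the eight flattened elements seen in the assertions:

```python
import torch

KV_LEN, NUM_QG, HEAD_DIM = 4, 1, 2       # assumed sizes: 4 * 1 * 2 = 8 elements
cache = torch.zeros(1, KV_LEN, NUM_QG, HEAD_DIM)
k_slice = torch.full((1, 2, NUM_QG, HEAD_DIM), 7.0)

start = 0                                # input_pos[0]
cache[:, start:start + 2] = k_slice      # contiguous write, as with DUS
print(cache.flatten().tolist())          # [7.0, 7.0, 7.0, 7.0, 0.0, 0.0, 0.0, 0.0]
```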
--- a/ai_edge_torch/generative/test/test_model_conversion.py
+++ b/ai_edge_torch/generative/test/test_model_conversion.py
@@ -16,12 +16,10 @@
 """Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -84,25 +82,25 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
-  def
-    """Tests that the model has the
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -112,7 +110,7 @@ class TestModelConversion(googletest.TestCase):
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
@@ -180,12 +178,12 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
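The rewritten op test asserts that conversion now emits TFLite's builtin `DYNAMIC_UPDATE_SLICE` operator rather than a custom op. A sketch of the same inspection on an arbitrary converted model, reusing the private `_get_ops_details()` call exactly as the test does; the model path is a placeholder:

```python
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="/tmp/toy_model.tflite")
# pylint: disable=protected-access -- same private API the test relies on.
op_names = [op["op_name"] for op in interpreter._get_ops_details()]
print("DYNAMIC_UPDATE_SLICE" in op_names)
```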
--- a/ai_edge_torch/generative/test/test_model_conversion_large.py
+++ b/ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -16,7 +16,6 @@
 """Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
@@ -32,7 +31,6 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -53,12 +51,15 @@ class TestModelConversion(googletest.TestCase):
             experimental_default_delegate_latest_features=True,
         )
     )
+    # Default cache_size_limit, 8 is hit and aborts often when the tests are
+    # running all together. Doubles it to avoid abortion.
+    torch._dynamo.config.cache_size_limit = 16
+    np.random.seed(1234)  # Make np.random deterministic.
 
   def _test_model(self, config, model, signature_name, atol, rtol):
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, 10, dtype=torch.int)
+    seq_len = 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
@@ -74,6 +75,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -88,19 +90,17 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
-    )
+    pytorch_model = gemma1.Gemma1(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
@@ -108,8 +108,8 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_llama(self):
     config = llama.get_fake_model_config()
@@ -117,19 +117,18 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-
-
-    )
+    pytorch_model = phi2.Phi2(config).eval()
+    # Phi-2 logits are very big, so we need a larger absolute tolerance.
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
@@ -137,58 +136,58 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
+    pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
-    pytorch_model = model_builder.DecoderOnlyModel(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-
+    pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
-  def test_paligemma(self):
+  def disabled_test_paligemma(self):
     config = paligemma.get_fake_model_config()
     pytorch_model = paligemma.PaliGemma(config).eval()
-
+
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
     ) ** 2
+
     # Make sure the token size is longer than the number of image patches.
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    seq_len = num_patches + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
     pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
 
@@ -206,6 +205,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -221,8 +221,8 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
@@ -244,7 +244,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="encode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
             edge_output,
             torch_output.detach().numpy(),
             atol=1e-4,
@@ -253,19 +253,21 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+        np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
     )
     context = torch.from_numpy(
-        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+        np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
     )
     time_embedding = torch.from_numpy(
-        np.random.normal(size=(2, 2)).astype(np.float32)
+        np.random.normal(size=(2, 2), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_diffusion.Diffusion(config).eval()
@@ -284,7 +286,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="diffusion",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -293,13 +295,15 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+        np.random.normal(size=(1, 4, 64, 64), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_decoder.Decoder(config).eval()
@@ -316,10 +320,10 @@ class TestModelConversion(googletest.TestCase):
         signature_name="decode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
-            atol=1e-
+            atol=1e-3,
            rtol=1e-5,
        )
    )
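On the `setUp` change: `torch._dynamo.config.cache_size_limit` caps how many compiled variants Dynamo caches per code object, and the in-code comment notes the default of 8 is exhausted when many models convert in one process. The knob is plain module state:

```python
import torch

# Assumption: raising the cap trades a little memory for fewer recompile
# aborts when several exports run in the same process, mirroring setUp above.
torch._dynamo.config.cache_size_limit = 16
```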
--- a/ai_edge_torch/generative/test/utils.py
+++ b/ai_edge_torch/generative/test/utils.py
@@ -15,6 +15,8 @@
 
 """Common utils for testing."""
 
+import logging
+
 from ai_edge_torch import model
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.lowertools import common_utils
@@ -33,7 +35,7 @@ def compare_tflite_torch(
     atol: float = 1e-5,
     rtol: float = 1e-5,
     **kwargs,
-):
+) -> bool:
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
@@ -49,9 +51,32 @@ def compare_tflite_torch(
       **kwargs,
   )
 
-  return np.allclose(
-      edge_output["logits"],
-      torch_output["logits"].detach().numpy(),
-      atol=atol,
-      rtol=rtol,
+  return compare_logits(
+      edge_output["logits"], torch_output["logits"].detach().numpy(), atol, rtol
   )
+
+
+def compare_logits(
+    edge_logits: np.ndarray,
+    torch_logits: dict[str, torch.Tensor],
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+) -> bool:
+  """Compares logits from edge model and torch model."""
+  if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+    return True
+
+  logging.info("edge_logits: %s", edge_logits)
+  logging.info("torch_logits: %s", torch_logits)
+
+  orig_atol = atol
+  while rtol < 1:
+    atol = orig_atol
+    while atol < 1:
+      if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+        logging.info("Got allclose true with atol=%s, rtol=%s", atol, rtol)
+        return False
+      atol *= 10
+    rtol *= 10
+  logging.info("allclose failed with reasonable atol and rtol.")
+  return False