ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +24 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +12 -14
- ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
+class SmolLM(model_builder.DecoderOnlyModel):
+  """A SmolLM model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a SmolLM 135M model.
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=SmolLM,
   )
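The hunks above add a SmolLM subclass and thread it through the shared builder via the new model_class argument, so build_model now returns a model-specific nn.Module. A minimal sketch of the effect for callers; the checkpoint path is a hypothetical placeholder:

    # Sketch only: "/path/to/smollm" is a placeholder, not a path from this diff.
    from ai_edge_torch.generative.examples.smollm import smollm

    model = smollm.build_model("/path/to/smollm", kv_cache_max_len=1024)
    # build_decoder_only_model instantiates the model_class it is given, so the
    # result is a SmolLM instance rather than a generic DecoderOnlyModel.
    assert type(model).__name__ == "SmolLM"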
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -15,13 +15,14 @@
 
 """A toy example which has basic transformer block (w/ externalized KV-Cache)."""
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 from torch import nn
 
@@ -62,6 +63,7 @@ class ToyModelWithKVCache(torch.nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
   ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
     x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
@@ -77,8 +79,16 @@ class ToyModelWithKVCache(torch.nn.Module):
       if kv_entry:
         updated_kv_entires.append(kv_entry)
 
-    x = self.final_norm(x)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {'kv_cache': updated_kv_cache}
+
+    x = self.final_norm(x)
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
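The forward() change above returns only the updated KV cache when more than one position is fed (prefill) and the export config does not ask for prefill logits, skipping the final norm and lm_head. The gate in isolation, as a self-contained sketch:

    # Sketch only: the prefill-logits gate added to forward() above.
    import torch

    def wants_logits(input_pos: torch.Tensor, output_logits_on_prefill: bool) -> bool:
      # Prefill feeds many positions at once; a decode step feeds exactly one.
      # Logits are computed on prefill only when the export config requests them.
      return torch.numel(input_pos) == 1 or output_logits_on_prefill

    assert not wants_logits(torch.arange(8), False)  # prefill, gate closed
    assert wants_logits(torch.tensor([8]), False)    # decode step, logits kept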
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -63,6 +64,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
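With this change the example convert scripts pass an ExportConfig through to the converter; a default ExportConfig() keeps existing behavior. A sketch of the updated call pattern, assuming the entry point is converter.convert_to_tflite as used by these example scripts, with placeholder paths and lengths:

    # Sketch only: paths and prefill length are illustrative placeholders.
    from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    pytorch_model = tiny_llama.build_model("/path/to/tiny_llama", kv_cache_max_len=1280)
    converter.convert_to_tflite(
        pytorch_model,
        tflite_path="/tmp/tiny_llama.tflite",
        prefill_seq_len=1024,
        quantize=True,
        export_config=ExportConfig(),  # defaults preserve existing behavior
    )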
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class TinyLlama(model_builder.DecoderOnlyModel):
+  """A TinyLlama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a TinyLlama model.
 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=TinyLlama,
   )
ai_edge_torch/generative/layers/attention.py
@@ -241,9 +241,7 @@ class CausalSelfAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ class CrossAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
ai_edge_torch/generative/layers/kv_cache.py
@@ -20,6 +20,7 @@ from typing import List, Tuple
 
 from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
+from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
 import torch
 import torch.utils._pytree as pytree
 
@@ -146,7 +147,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
 
@@ -155,17 +156,12 @@ def update(
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-        High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-
-
-  enable_hlfb=False
-  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
-  return update_func(cache, input_pos, k_slice, v_slice)
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +177,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def _update_kv_hlfb_impl(
+def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+  """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
 ) -> KVCacheEntry:
-  """Update the cache buffer
-
-
-
-  )
-
-
-
-
+  """Update the cache buffer for K and V caches."""
+  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
ai_edge_torch/generative/layers/normalization.py
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
-  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
           "epsilon": eps,
-          "reduction_axes": 3,
+          "reduction_axes": [3],
           "channel_axis": 3,
       },
   )
ai_edge_torch/generative/test/test_kv_cache.py
@@ -71,18 +71,18 @@ class TestKVLayers(googletest.TestCase):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0,
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):
ai_edge_torch/generative/test/test_model_conversion.py
@@ -16,12 +16,10 @@
 """Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -84,25 +82,25 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
-  def
-    """Tests that the model has the
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -112,7 +110,7 @@ class TestModelConversion(googletest.TestCase):
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
@@ -180,12 +178,12 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
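The guards above replace the deleted ai_edge_torch/config.py with a module-level flag from the new ai_edge_torch/_config.py, read as ai_edge_torch.config.in_oss. A sketch of the same guard in a fresh test, using stdlib unittest since the googletest import these files use sits outside the hunks shown:

    # Sketch only: unittest stands in for the googletest module used above.
    import unittest

    import ai_edge_torch

    class CustomOpTest(unittest.TestCase):

      @unittest.skipIf(
          ai_edge_torch.config.in_oss,
          "tests with custom ops are not supported in oss",
      )
      def test_custom_op(self):
        ...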
ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -16,7 +16,6 @@
 """Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
@@ -32,7 +31,6 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_dec
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -53,12 +51,15 @@ class TestModelConversion(googletest.TestCase):
           experimental_default_delegate_latest_features=True,
       )
     )
+    # Default cache_size_limit, 8 is hit and aborts often when the tests are
+    # running all together. Doubles it to avoid abortion.
+    torch._dynamo.config.cache_size_limit = 16
+    np.random.seed(1234)  # Make np.random deterministic.
 
   def _test_model(self, config, model, signature_name, atol, rtol):
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, 10, dtype=torch.int)
+    seq_len = 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
@@ -74,6 +75,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -88,19 +90,17 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model =
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
-    )
+    pytorch_model = gemma1.Gemma1(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
@@ -108,8 +108,8 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_llama(self):
     config = llama.get_fake_model_config()
@@ -117,19 +117,18 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model =
-
-
-    )
+    pytorch_model = phi2.Phi2(config).eval()
+    # Phi-2 logits are very big, so we need a larger absolute tolerance.
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
@@ -137,58 +136,58 @@ class TestModelConversion(googletest.TestCase):
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
  )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = qwen.Qwen(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
-    pytorch_model =
-    self._test_model(config, pytorch_model, "prefill", atol=1e-
+    pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
-  def
+  def disabled_test_paligemma(self):
     config = paligemma.get_fake_model_config()
     pytorch_model = paligemma.PaliGemma(config).eval()
-
+
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
     ) ** 2
+
     # Make sure the token size is longer than the number of image patches.
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    seq_len = num_patches + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
     pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
 
@@ -206,6 +205,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -221,8 +221,8 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
@@ -244,7 +244,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="encode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -253,19 +253,21 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+        np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
     )
     context = torch.from_numpy(
-        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+        np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
     )
     time_embedding = torch.from_numpy(
-        np.random.normal(size=(2, 2)).astype(np.float32)
+        np.random.normal(size=(2, 2), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_diffusion.Diffusion(config).eval()
@@ -284,7 +286,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="diffusion",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -293,13 +295,15 @@ class TestModelConversion(googletest.TestCase):
     )
 
   @googletest.skipIf(
-
-      reason="tests with custom ops are not supported
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+        np.random.normal(size=(1, 4, 64, 64), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_decoder.Decoder(config).eval()
@@ -316,10 +320,10 @@ class TestModelConversion(googletest.TestCase):
         signature_name="decode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
-            atol=1e-
+            atol=1e-3,
            rtol=1e-5,
        )
     )
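Two stability knobs recur through the hunks above: a doubled torch._dynamo compile-cache limit so repeated conversions in one process do not abort, and seeded, low-variance random inputs so fixed atol/rtol comparisons stay meaningful. In isolation:

    # Sketch only: the two test-stability settings introduced above.
    import numpy as np
    import torch

    # Many (model, signature) exports in one process overflow the default
    # dynamo cache of 8 entries, so the suite raises it.
    torch._dynamo.config.cache_size_limit = 16

    # Seeded, scale=0.1 inputs keep outputs small enough for tight tolerances.
    np.random.seed(1234)
    latents = torch.from_numpy(
        np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
    )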
ai_edge_torch/generative/test/utils.py
@@ -15,6 +15,8 @@
 
 """Common utils for testing."""
 
+import logging
+
 from ai_edge_torch import model
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.lowertools import common_utils
@@ -33,7 +35,7 @@ def compare_tflite_torch(
     atol: float = 1e-5,
     rtol: float = 1e-5,
     **kwargs,
-):
+) -> bool:
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
@@ -49,9 +51,32 @@ def compare_tflite_torch(
       **kwargs,
   )
 
-  return
-      edge_output["logits"],
-      torch_output["logits"].detach().numpy(),
-      atol=atol,
-      rtol=rtol,
+  return compare_logits(
+      edge_output["logits"], torch_output["logits"].detach().numpy(), atol, rtol
   )
+
+
+def compare_logits(
+    edge_logits: np.ndarray,
+    torch_logits: dict[str, torch.Tensor],
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+) -> bool:
+  """Compares logits from edge model and torch model."""
+  if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+    return True
+
+  logging.info("edge_logits: %s", edge_logits)
+  logging.info("torch_logits: %s", torch_logits)
+
+  orig_atol = atol
+  while rtol < 1:
+    atol = orig_atol
+    while atol < 1:
+      if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+        logging.info("Got allclose true with atol=%s, rtol=%s", atol, rtol)
+        return False
+      atol *= 10
+    rtol *= 10
+  logging.info("allclose failed with reasonable atol and rtol.")
+  return False