ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241213__py3-none-any.whl
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +25 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +4 -5
- ai_edge_torch/generative/test/test_model_conversion_large.py +37 -32
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/RECORD +45 -44
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/top_level.txt +0 -0
```diff
--- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class TinyLlama(model_builder.DecoderOnlyModel):
+  """A TinyLlama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a TinyLlama model.
 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=TinyLlama,
   )
```
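The example models now subclass `model_builder.DecoderOnlyModel` and hand that class to `build_decoder_only_model` via the new `model_class` argument. A minimal usage sketch of the updated `build_model` entry point; the checkpoint path below is hypothetical:

```python
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama

# Hypothetical local checkpoint path; substitute a real TinyLlama checkpoint.
CHECKPOINT_PATH = "/tmp/tiny_llama_checkpoint"

# build_model forwards model_class=TinyLlama, so the returned nn.Module is a
# TinyLlama instance rather than a bare DecoderOnlyModel.
model = tiny_llama.build_model(CHECKPOINT_PATH, kv_cache_max_len=1024)
print(type(model).__name__)  # TinyLlama
```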
```diff
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -241,9 +241,7 @@ class CausalSelfAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ class CrossAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
```
```diff
--- a/ai_edge_torch/generative/layers/kv_cache.py
+++ b/ai_edge_torch/generative/layers/kv_cache.py
@@ -146,7 +146,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
 
@@ -155,17 +155,14 @@ def update(
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-      High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-  #
-
-
-
-  return update_func(cache, input_pos, k_slice, v_slice)
+  # Turn dynamic_update_slice updates off for now.
+  use_dus=False
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +178,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def
+def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+  """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
 ) -> KVCacheEntry:
-  """Update the cache buffer
-
-
-
-  )
-
-
-
-
+  """Update the cache buffer for K and V caches."""
+  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
```
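The new `_update_kv_impl` writes the K/V slices into the cache with a dynamic-update-slice op whose start indices are the 0-rank tensors built by `_get_slice_indices`; only the sequence dimension moves. A rough pure-PyTorch sketch of what that op computes (a reference for the semantics, not the library's lowering), assuming a `[batch, seq_len, num_heads, head_dim]` cache layout and a contiguous `input_pos`:

```python
import torch


def dus_reference(cache: torch.Tensor, update: torch.Tensor,
                  start: list[torch.Tensor]) -> torch.Tensor:
  """Out-of-place reference for a dynamic_update_slice on a 4-D tensor."""
  b, s, h, d = (int(i) for i in start)
  out = cache.clone()
  out[b:b + update.shape[0], s:s + update.shape[1],
      h:h + update.shape[2], d:d + update.shape[3]] = update
  return out


k_cache = torch.zeros(1, 8, 2, 4)        # [batch, seq, heads, head_dim]
k_slice = torch.full((1, 2, 2, 4), 7.0)  # two new positions worth of keys
input_pos = torch.tensor([3, 4])

# Mirrors _get_slice_indices: scalar zeros everywhere except the seq dimension.
zero = torch.zeros([]).int()
start = [zero, input_pos.int()[0].reshape([]), zero, zero]

print(dus_reference(k_cache, k_slice, start)[0, :, 0, 0])
# tensor([0., 0., 0., 7., 7., 0., 0., 0.])
```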
```diff
--- a/ai_edge_torch/generative/layers/normalization.py
+++ b/ai_edge_torch/generative/layers/normalization.py
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
-  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
           "epsilon": eps,
-          "reduction_axes": 3,
+          "reduction_axes": [3],
           "channel_axis": 3,
       },
   )
```
```diff
--- a/ai_edge_torch/generative/test/test_kv_cache.py
+++ b/ai_edge_torch/generative/test/test_kv_cache.py
@@ -71,18 +71,18 @@ class TestKVLayers(googletest.TestCase):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0,
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):
```
```diff
--- a/ai_edge_torch/generative/test/test_model_conversion.py
+++ b/ai_edge_torch/generative/test/test_model_conversion.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -101,8 +100,8 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
-    """Tests that the model has the
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -112,7 +111,7 @@ class TestModelConversion(googletest.TestCase):
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
@@ -185,7 +184,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 
```
```diff
--- a/ai_edge_torch/generative/test/test_model_conversion_large.py
+++ b/ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -32,7 +32,6 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_dec
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -53,12 +52,15 @@ class TestModelConversion(googletest.TestCase):
             experimental_default_delegate_latest_features=True,
         )
     )
+    # Default cache_size_limit, 8 is hit and aborts often when the tests are
+    # running all together. Doubles it to avoid abortion.
+    torch._dynamo.config.cache_size_limit = 16
+    np.random.seed(1234)  # Make np.random deterministic.
 
   def _test_model(self, config, model, signature_name, atol, rtol):
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, 10, dtype=torch.int)
+    seq_len = 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
@@ -74,6 +76,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -93,10 +96,8 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model =
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
-    )
+    pytorch_model = gemma1.Gemma1(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -122,10 +123,9 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model =
-
-
-    )
+    pytorch_model = phi2.Phi2(config).eval()
+    # Phi-2 logits are very big, so we need a larger absolute tolerance.
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -142,7 +142,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -151,7 +151,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -160,7 +160,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
@@ -169,26 +169,26 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
-    pytorch_model =
-    self._test_model(config, pytorch_model, "prefill", atol=1e-
+    pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
+  def disabled_test_paligemma(self):
     config = paligemma.get_fake_model_config()
     pytorch_model = paligemma.PaliGemma(config).eval()
-
+
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
     ) ** 2
+
     # Make sure the token size is longer than the number of image patches.
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    seq_len = num_patches + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
     pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
 
@@ -206,6 +206,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -244,7 +245,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="encode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -258,14 +259,16 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+        np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
     )
     context = torch.from_numpy(
-        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+        np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
     )
     time_embedding = torch.from_numpy(
-        np.random.normal(size=(2, 2)).astype(np.float32)
+        np.random.normal(size=(2, 2), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_diffusion.Diffusion(config).eval()
@@ -284,7 +287,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="diffusion",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -298,8 +301,10 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+        np.random.normal(size=(1, 4, 64, 64), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_decoder.Decoder(config).eval()
@@ -316,10 +321,10 @@ class TestModelConversion(googletest.TestCase):
         signature_name="decode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
-            atol=1e-
+            atol=1e-3,
            rtol=1e-5,
        )
    )
```
```diff
--- a/ai_edge_torch/generative/test/utils.py
+++ b/ai_edge_torch/generative/test/utils.py
@@ -15,6 +15,8 @@
 
 """Common utils for testing."""
 
+import logging
+
 from ai_edge_torch import model
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.lowertools import common_utils
@@ -33,7 +35,7 @@ def compare_tflite_torch(
     atol: float = 1e-5,
     rtol: float = 1e-5,
     **kwargs,
-):
+) -> bool:
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
@@ -49,9 +51,32 @@ def compare_tflite_torch(
       **kwargs,
   )
 
-  return
-      edge_output["logits"],
-      torch_output["logits"].detach().numpy(),
-      atol=atol,
-      rtol=rtol,
+  return compare_logits(
+      edge_output["logits"], torch_output["logits"].detach().numpy(), atol, rtol
   )
+
+
+def compare_logits(
+    edge_logits: np.ndarray,
+    torch_logits: dict[str, torch.Tensor],
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+) -> bool:
+  """Compares logits from edge model and torch model."""
+  if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+    return True
+
+  logging.info("edge_logits: %s", edge_logits)
+  logging.info("torch_logits: %s", torch_logits)
+
+  orig_atol = atol
+  while rtol < 1:
+    atol = orig_atol
+    while atol < 1:
+      if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+        logging.info("Got allclose true with atol=%s, rtol=%s", atol, rtol)
+        return False
+      atol *= 10
+    rtol *= 10
+  logging.info("allclose failed with reasonable atol and rtol.")
+  return False
```
```diff
--- a/ai_edge_torch/generative/utilities/converter.py
+++ b/ai_edge_torch/generative/utilities/converter.py
@@ -15,13 +15,28 @@
 
 """Common utility functions for model conversion."""
 
-from
+from functools import partial
+from typing import Any, Union
 
 from ai_edge_torch._convert import converter as converter_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
+import torch.nn as nn
+
+
+class ExportableModule(torch.nn.Module):
+
+  def __init__(self, module, **extra_kwargs):
+    super().__init__()
+    self.module = module
+    self.extra_kwargs = extra_kwargs
+
+  def forward(self, *export_args, **export_kwargs):
+    full_kwargs = {**export_kwargs, **self.extra_kwargs}
+    return self.module(*export_args, **full_kwargs)
 
 
 def convert_to_tflite(
@@ -31,6 +46,7 @@ def convert_to_tflite(
     pixel_values_size: torch.Size = None,
     quantize: bool = True,
     config: cfg.ModelConfig = None,
+    export_config: ExportConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.
 
@@ -97,6 +113,11 @@ def convert_to_tflite(
   )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+
+  # For export, we create a module that captures any non-exportable,
+  # arugments, e.g. the generation config object.
+  mod = ExportableModule(pytorch_model, export_config=export_config)
+
   converter = converter_utils.Converter()
   for i in range(len(prefill_seq_lens)):
     prefill_seq_len = prefill_seq_lens[i]
@@ -108,7 +129,7 @@ def convert_to_tflite(
     prefill_signature_name = f'prefill_{prefill_seq_len}'
     converter.add_signature(
         prefill_signature_name,
-
+        mod,
         sample_kwargs={
            'tokens': prefill_tokens,
            'input_pos': prefill_input_pos,
@@ -118,7 +139,7 @@ def convert_to_tflite(
     if prefill_pixel_values is not None:
       converter.add_signature(
           prefill_signature_name + '_pixel',
-
+          mod,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
@@ -129,7 +150,7 @@ def convert_to_tflite(
 
   converter.add_signature(
       'decode',
-
+      mod,
      sample_kwargs={
          'tokens': decode_token,
          'input_pos': decode_input_pos,
```
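`ExportableModule` exists because export only traces tensor-like sample inputs; binding the `ExportConfig` object (or any other non-exportable keyword argument) onto the wrapper turns it into a constant of the exported graph. A self-contained sketch of the same pattern with a toy model; the names here are illustrative, not the library's:

```python
import dataclasses
import torch
import torch.nn as nn


@dataclasses.dataclass
class ToyExportConfig:
  output_logits_on_prefill: bool = False


class ToyDecoder(nn.Module):
  """Stand-in for a DecoderOnlyModel-style forward taking an export_config kwarg."""

  def __init__(self):
    super().__init__()
    self.proj = nn.Linear(4, 4)

  def forward(self, tokens, export_config=None):
    x = self.proj(tokens)
    if export_config is not None and not export_config.output_logits_on_prefill:
      return {"hidden": x}                 # prefill: skip the logits head
    return {"hidden": x, "logits": x.sum(-1)}


class ExportableModule(nn.Module):
  """Binds non-tensor kwargs so only tensors remain as export-time inputs."""

  def __init__(self, module, **extra_kwargs):
    super().__init__()
    self.module = module
    self.extra_kwargs = extra_kwargs

  def forward(self, *args, **kwargs):
    return self.module(*args, **kwargs, **self.extra_kwargs)


mod = ExportableModule(ToyDecoder(), export_config=ToyExportConfig())
out = mod(torch.randn(2, 4))   # export_config is injected by the wrapper
print(sorted(out))             # ['hidden'] -- no logits on prefill
```

In the diff above, only the wrapped `mod` is registered with `converter.add_signature`, with tensors (`tokens`, `input_pos`, the KV cache) supplied as `sample_kwargs`.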
```diff
--- a/ai_edge_torch/generative/utilities/model_builder.py
+++ b/ai_edge_torch/generative/utilities/model_builder.py
@@ -16,7 +16,8 @@
 """Utilities to be used for re-authoring transformer models."""
 
 import copy
-from
+from dataclasses import dataclass
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
 TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
 
 
+@dataclass
+class ExportConfig:
+  """Model generating configuration settings."""
+
+  # On prefill signatures, should the model produce logit output?
+  # When False, only decode signatures will produce output.
+  output_logits_on_prefill: bool = False
+
+
 class DecoderOnlyModel(nn.Module):
   """A simple decoder-only transformer model built from the Edge Generative API.
 
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self.forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
   def forward_with_embeds(
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
       mask: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
       updated_kv_entires.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {"kv_cache": updated_kv_cache}
+
     x = self.final_norm(x)
     logits = self.lm_head(x)  # (b, t, vocab_size)
     return {"logits": logits, "kv_cache": updated_kv_cache}
@@ -146,8 +165,9 @@ def build_decoder_only_model(
     checkpoint_path: str,
     config: cfg.ModelConfig,
     tensor_names: loading_utils.ModelLoader.TensorNames,
-
-
+    model_class: type[nn.Module] = DecoderOnlyModel,
+) -> nn.Module:
+  transformer = model_class(config)
   loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
   loader.load(
       transformer, strict=not config.lm_head_share_weight_with_embedding
```
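With the new `export_config` argument, prefill calls (where `input_pos` has more than one element) can skip the LM head entirely. A behavior sketch using the fake TinyLlama config from the tests; it assumes the fake config accepts a short input sequence:

```python
import torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

config = tiny_llama.get_fake_model_config()
model = tiny_llama.TinyLlama(config).eval()
kv = kv_cache.KVCache.from_model_config(config)

tokens = torch.zeros((1, 8), dtype=torch.int)
input_pos = torch.arange(0, 8, dtype=torch.int)

# Prefill with the default output_logits_on_prefill=False: only the updated
# KV cache comes back, so a converted prefill signature carries no logits.
prefill_out = model(tokens, input_pos, kv, export_config=ExportConfig())
print(sorted(prefill_out))        # ['kv_cache']

# Without an export_config (e.g. during verification), logits are computed.
full_out = model(tokens, input_pos, kv)
print(sorted(full_out))           # ['kv_cache', 'logits']
```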
```diff
--- a/ai_edge_torch/generative/utilities/verifier.py
+++ b/ai_edge_torch/generative/utilities/verifier.py
@@ -19,6 +19,7 @@ import logging
 from typing import List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
 
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
     """
     super().__init__()
     self.model = model
+    self.export_config = ExportConfig(output_logits_on_prefill=True)
 
   def forward(
       self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
     Returns:
       The output logits and the updated KV cache.
     """
+    # Verification requires logit outputs on prefill for comparison.
+    if (
+        self.export_config is not None
+        and not self.export_config.output_logits_on_prefill
+    ):
+      raise ValueError("Verifier requires logit output on prefill.")
     # Since the reauthored model doesn't include keyword arguments, pass
     # pixel_values only when it is not None. Otherwise, it may raise an error.
     if pixel_values is None:
-      output = self.model.forward(
+      output = self.model.forward(
+          tokens, input_pos, kv_cache, export_config=self.export_config
+      )
     else:
       output = self.model.forward(
-          tokens,
+          tokens,
+          input_pos,
+          kv_cache,
+          pixel_values=pixel_values,
+          export_config=self.export_config,
       )
     return output["logits"], output["kv_cache"]
 
```