ai-edge-torch-nightly 0.3.0.dev20241205__py3-none-any.whl → 0.3.0.dev20241213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +25 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +4 -5
- ai_edge_torch/generative/test/test_model_conversion_large.py +37 -32
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241205.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20241205.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/RECORD +45 -44
- {ai_edge_torch_nightly-0.3.0.dev20241205.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241205.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241205.dist-info → ai_edge_torch_nightly-0.3.0.dev20241213.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class TinyLlama(model_builder.DecoderOnlyModel):
+  """A TinyLlama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a TinyLlama model.
 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=TinyLlama,
   )
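
The per-example wrapper class plus the new `model_class` argument let each example expose its own type while reusing the shared builder. A minimal, self-contained sketch of the pattern (stand-in classes, not the real API):

```python
import torch.nn as nn


class DecoderOnlyModel(nn.Module):  # stand-in for model_builder.DecoderOnlyModel
  def __init__(self, config):
    super().__init__()
    self.config = config


class TinyLlama(DecoderOnlyModel):
  """Per-example subclass; behavior is inherited unchanged."""
  pass


def build_decoder_only_model(config, model_class=DecoderOnlyModel):
  # The builder instantiates whatever class it is handed.
  return model_class(config)


model = build_decoder_only_model({"kv_cache_max_len": 1024}, model_class=TinyLlama)
print(type(model).__name__)  # TinyLlama
```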

ai_edge_torch/generative/layers/attention.py
@@ -241,9 +241,7 @@ class CausalSelfAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
@@ -379,9 +377,7 @@ class CrossAttention(nn.Module):
       q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
-      kv_cache = kv_utils.update(
-          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
-      )
+      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(

ai_edge_torch/generative/layers/kv_cache.py
@@ -146,7 +146,7 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    enable_hlfb: bool = True,
+    use_dus: bool = True,
 ) -> KVCacheEntry:
   """Out of place update of Cache buffer.
 
@@ -155,17 +155,14 @@ def update(
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
-    enable_hlfb (bool, optional): Whether the op is annotated for export with
-      High Level Function Boundary. Defaults to True.
 
   Returns:
     KVCacheEntry: The updated KVCache entry based on the passed inputs.
   """
-  #
-
-
-
-  return update_func(cache, input_pos, k_slice, v_slice)
+  # Turn dynamic_update_slice updates off for now.
+  use_dus=False
+  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _update_kv_base_impl(
@@ -181,18 +178,28 @@ def _update_kv_base_impl(
   return updated_cache
 
 
-def
+def _get_slice_indices(positions: torch.Tensor) -> torch.Tensor:
+  """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
+
+  zero = torch.zeros([]).int()
+  positions = positions.int()[0].reshape([])
+  return [zero, positions, zero, zero]
+
+
+def _update_kv_impl(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
    k_slice: torch.Tensor,
    v_slice: torch.Tensor,
) -> KVCacheEntry:
-  """Update the cache buffer
-
-
-
-  )
-
-
-
-
+  """Update the cache buffer for K and V caches."""
+  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
+
+  k_slice_indices = _get_slice_indices(input_pos)
+  v_slice_indices = _get_slice_indices(input_pos)
+
+  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
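
For orientation, here is a torch-only sketch of the out-of-place cache write that the dynamic_update_slice path performs. The shapes and the plain-indexing stand-in are assumptions; the real implementation lowers to a dynamic_update_slice op:

```python
import torch

# Assumed cache layout: (batch, kv_len, num_heads, head_dim).
B, KV_LEN, N_HEADS, HEAD_DIM = 1, 8, 2, 4
k_cache = torch.zeros(B, KV_LEN, N_HEADS, HEAD_DIM)
k_slice = torch.full((B, 2, N_HEADS, HEAD_DIM), 7.0)
input_pos = torch.tensor([0, 1])  # assumed contiguous, as the NB comment notes

# Plain-indexing stand-in for a dynamic_update_slice at offset (0, input_pos[0], 0, 0).
start = int(input_pos[0])
new_k = k_cache.clone()  # out of place: the original buffer is untouched
new_k[:, start:start + k_slice.shape[1]] = k_slice
print(new_k[0, :, 0, 0])  # first two positions now hold 7.0
```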

ai_edge_torch/generative/layers/normalization.py
@@ -190,14 +190,12 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
-  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
-  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
           "epsilon": eps,
-          "reduction_axes": 3,
+          "reduction_axes": [3],
           "channel_axis": 3,
       },
   )

ai_edge_torch/generative/test/test_kv_cache.py
@@ -71,18 +71,18 @@ class TestKVLayers(googletest.TestCase):
         [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
-    input_pos = torch.tensor([0,
+    input_pos = torch.tensor([0, 1])
     k_slice = v_slice = torch.full(
         (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
     )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
         updated_entry.k_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
     self.assertEqual(
         updated_entry.v_cache.numpy().flatten().tolist(),
-        [7, 7,
+        [7, 7, 7, 7, 0, 0, 0, 0],
     )
 
   def test_serialization(self):

ai_edge_torch/generative/test/test_model_conversion.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -101,8 +100,8 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
-    """Tests that the model has the
+  def test_toy_model_has_dus_op(self):
+    """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(enable_hlfb=True)
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
@@ -112,7 +111,7 @@ class TestModelConversion(googletest.TestCase):
 
     # pylint: disable=protected-access
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
-    self.assertIn("
+    self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
@@ -185,7 +184,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
 

ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -32,7 +32,6 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_dec
 from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
-from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -53,12 +52,15 @@ class TestModelConversion(googletest.TestCase):
             experimental_default_delegate_latest_features=True,
         )
     )
+    # Default cache_size_limit, 8 is hit and aborts often when the tests are
+    # running all together. Doubles it to avoid abortion.
+    torch._dynamo.config.cache_size_limit = 16
+    np.random.seed(1234)  # Make np.random deterministic.
 
   def _test_model(self, config, model, signature_name, atol, rtol):
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, 10, dtype=torch.int)
+    seq_len = 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
@@ -74,6 +76,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -93,10 +96,8 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model =
-    self._test_model(
-        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
-    )
+    pytorch_model = gemma1.Gemma1(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -122,10 +123,9 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model =
-
-
-    )
+    pytorch_model = phi2.Phi2(config).eval()
+    # Phi-2 logits are very big, so we need a larger absolute tolerance.
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -142,7 +142,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -151,7 +151,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
@@ -160,7 +160,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model =
+    pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   @googletest.skipIf(
@@ -169,26 +169,26 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
-    pytorch_model =
-    self._test_model(config, pytorch_model, "prefill", atol=1e-
+    pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
+  def disabled_test_paligemma(self):
     config = paligemma.get_fake_model_config()
     pytorch_model = paligemma.PaliGemma(config).eval()
-
+
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
     ) ** 2
+
     # Make sure the token size is longer than the number of image patches.
-
-    tokens = torch.zeros((1,
-
-    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    seq_len = num_patches + 10
+    tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
+    input_pos = torch.arange(0, seq_len, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config.decoder_config)
     pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
 
@@ -206,6 +206,7 @@ class TestModelConversion(googletest.TestCase):
         self._interpreter_builder(edge_model.tflite_model())
     )
 
+    tokens = torch.arange(1, seq_len + 1, dtype=torch.int).unsqueeze(0)
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
@@ -244,7 +245,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="encode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -258,14 +259,16 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+        np.random.normal(size=(2, 4, 8, 8), scale=0.1).astype(np.float32)
     )
     context = torch.from_numpy(
-        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+        np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
     )
     time_embedding = torch.from_numpy(
-        np.random.normal(size=(2, 2)).astype(np.float32)
+        np.random.normal(size=(2, 2), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_diffusion.Diffusion(config).eval()
@@ -284,7 +287,7 @@ class TestModelConversion(googletest.TestCase):
         signature_name="diffusion",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
@@ -298,8 +301,10 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
+    # Reduce stddev(scale) of input values to avoid too big output logits which
+    # fails comparisons with reasonable tolerances.
     latents = torch.from_numpy(
-        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+        np.random.normal(size=(1, 4, 64, 64), scale=0.1).astype(np.float32)
     )
 
     pytorch_model = sd_decoder.Decoder(config).eval()
@@ -316,10 +321,10 @@ class TestModelConversion(googletest.TestCase):
         signature_name="decode",
     )
     self.assertTrue(
-
+        test_utils.compare_logits(
            edge_output,
            torch_output.detach().numpy(),
-            atol=1e-
+            atol=1e-3,
            rtol=1e-5,
        )
    )
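
The `scale` argument above is the standard deviation of `np.random.normal`; shrinking it keeps the fake-model outputs small enough for the tolerances in these tests. A quick illustration:

```python
import numpy as np

np.random.seed(1234)  # same seeding idea as the test setUp above
x = np.random.normal(size=(2, 4, 4), scale=0.1).astype(np.float32)
print(round(float(x.std()), 3))  # close to 0.1
```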

ai_edge_torch/generative/test/utils.py
@@ -15,6 +15,8 @@
 
 """Common utils for testing."""
 
+import logging
+
 from ai_edge_torch import model
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.lowertools import common_utils
@@ -33,7 +35,7 @@ def compare_tflite_torch(
     atol: float = 1e-5,
     rtol: float = 1e-5,
     **kwargs,
-):
+) -> bool:
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
@@ -49,9 +51,32 @@ def compare_tflite_torch(
       **kwargs,
   )
 
-  return
-      edge_output["logits"],
-      torch_output["logits"].detach().numpy(),
-      atol=atol,
-      rtol=rtol,
+  return compare_logits(
+      edge_output["logits"], torch_output["logits"].detach().numpy(), atol, rtol
   )
+
+
+def compare_logits(
+    edge_logits: np.ndarray,
+    torch_logits: dict[str, torch.Tensor],
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+) -> bool:
+  """Compares logits from edge model and torch model."""
+  if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+    return True
+
+  logging.info("edge_logits: %s", edge_logits)
+  logging.info("torch_logits: %s", torch_logits)
+
+  orig_atol = atol
+  while rtol < 1:
+    atol = orig_atol
+    while atol < 1:
+      if np.allclose(edge_logits, torch_logits, rtol, atol, equal_nan=True):
+        logging.info("Got allclose true with atol=%s, rtol=%s", atol, rtol)
+        return False
+      atol *= 10
+    rtol *= 10
+  logging.info("allclose failed with reasonable atol and rtol.")
+  return False
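
Note that the fallback loop in `compare_logits` only reports which looser tolerances would have passed; the function still returns False. A standalone sketch of that diagnostic search (the helper name is hypothetical):

```python
import numpy as np


def loosest_passing_tolerances(a, b, atol=1e-5, rtol=1e-5):
  """Returns the first (atol, rtol) pair that makes allclose succeed, or None."""
  orig_atol = atol
  while rtol < 1:
    atol = orig_atol
    while atol < 1:
      if np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=True):
        return atol, rtol
      atol *= 10
    rtol *= 10
  return None


print(loosest_passing_tolerances(np.array([1.0]), np.array([1.002])))  # about (1e-2, 1e-5)
```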

ai_edge_torch/generative/utilities/converter.py
@@ -15,13 +15,28 @@
 
 """Common utility functions for model conversion."""
 
-from
+from functools import partial
+from typing import Any, Union
 
 from ai_edge_torch._convert import converter as converter_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
+import torch.nn as nn
+
+
+class ExportableModule(torch.nn.Module):
+
+  def __init__(self, module, **extra_kwargs):
+    super().__init__()
+    self.module = module
+    self.extra_kwargs = extra_kwargs
+
+  def forward(self, *export_args, **export_kwargs):
+    full_kwargs = {**export_kwargs, **self.extra_kwargs}
+    return self.module(*export_args, **full_kwargs)
 
 
 def convert_to_tflite(
@@ -31,6 +46,7 @@ def convert_to_tflite(
     pixel_values_size: torch.Size = None,
     quantize: bool = True,
     config: cfg.ModelConfig = None,
+    export_config: ExportConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.
 
@@ -97,6 +113,11 @@ def convert_to_tflite(
   )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+
+  # For export, we create a module that captures any non-exportable,
+  # arugments, e.g. the generation config object.
+  mod = ExportableModule(pytorch_model, export_config=export_config)
+
   converter = converter_utils.Converter()
   for i in range(len(prefill_seq_lens)):
     prefill_seq_len = prefill_seq_lens[i]
@@ -108,7 +129,7 @@ def convert_to_tflite(
     prefill_signature_name = f'prefill_{prefill_seq_len}'
     converter.add_signature(
         prefill_signature_name,
-
+        mod,
         sample_kwargs={
            'tokens': prefill_tokens,
            'input_pos': prefill_input_pos,
@@ -118,7 +139,7 @@ def convert_to_tflite(
    if prefill_pixel_values is not None:
      converter.add_signature(
          prefill_signature_name + '_pixel',
-
+          mod,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
@@ -129,7 +150,7 @@ def convert_to_tflite(
 
   converter.add_signature(
       'decode',
-
+      mod,
       sample_kwargs={
          'tokens': decode_token,
          'input_pos': decode_input_pos,
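
The `ExportableModule` wrapper simply closes over keyword arguments that should not become signature inputs (here `export_config`). A toy sketch of the same pattern with a stand-in module:

```python
import torch
import torch.nn as nn


class ExportableModule(nn.Module):
  def __init__(self, module, **extra_kwargs):
    super().__init__()
    self.module = module
    self.extra_kwargs = extra_kwargs

  def forward(self, *args, **kwargs):
    return self.module(*args, **{**kwargs, **self.extra_kwargs})


class Scaler(nn.Module):  # stand-in for a real generative model
  def forward(self, x, scale=1.0):
    return x * scale


mod = ExportableModule(Scaler(), scale=2.0)
print(mod(torch.ones(2)))  # tensor([2., 2.]) — `scale` was bound at wrap time
```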

ai_edge_torch/generative/utilities/model_builder.py
@@ -16,7 +16,8 @@
 """Utilities to be used for re-authoring transformer models."""
 
 import copy
-from
+from dataclasses import dataclass
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
 TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
 
 
+@dataclass
+class ExportConfig:
+  """Model generating configuration settings."""
+
+  # On prefill signatures, should the model produce logit output?
+  # When False, only decode signatures will produce output.
+  output_logits_on_prefill: bool = False
+
+
 class DecoderOnlyModel(nn.Module):
   """A simple decoder-only transformer model built from the Edge Generative API.
 
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self.forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
   def forward_with_embeds(
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
      mask: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
+      export_config: Optional[ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Forwards the model with input embeddings."""
    assert len(self.transformer_blocks) == len(kv_cache.caches), (
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
      updated_kv_entires.append(kv_entry)
    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {"kv_cache": updated_kv_cache}
+
     x = self.final_norm(x)
     logits = self.lm_head(x)  # (b, t, vocab_size)
     return {"logits": logits, "kv_cache": updated_kv_cache}
@@ -146,8 +165,9 @@ def build_decoder_only_model(
     checkpoint_path: str,
     config: cfg.ModelConfig,
     tensor_names: loading_utils.ModelLoader.TensorNames,
-
-
+    model_class: type[nn.Module] = DecoderOnlyModel,
+) -> nn.Module:
+  transformer = model_class(config)
   loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
   loader.load(
       transformer, strict=not config.lm_head_share_weight_with_embedding
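
A toy illustration (not the real model) of how `output_logits_on_prefill` gates the output dict: a multi-token `input_pos` is treated as prefill and, with the default `False`, only the cache is returned:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ExportConfig:
  output_logits_on_prefill: bool = False


def forward(input_pos: torch.Tensor, export_config: Optional[ExportConfig] = None):
  kv_cache = {"dummy": torch.zeros(1)}  # stand-in for the real KVCache
  if (
      export_config is not None
      and torch.numel(input_pos) > 1
      and not export_config.output_logits_on_prefill
  ):
    return {"kv_cache": kv_cache}
  return {"logits": torch.zeros(1, 4), "kv_cache": kv_cache}


print(forward(torch.arange(8), ExportConfig()).keys())    # prefill: kv_cache only
print(forward(torch.tensor([8]), ExportConfig()).keys())  # decode: logits + kv_cache
```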

ai_edge_torch/generative/utilities/verifier.py
@@ -19,6 +19,7 @@ import logging
 from typing import List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
 
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
     """
     super().__init__()
     self.model = model
+    self.export_config = ExportConfig(output_logits_on_prefill=True)
 
   def forward(
       self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
     Returns:
       The output logits and the updated KV cache.
     """
+    # Verification requires logit outputs on prefill for comparison.
+    if (
+        self.export_config is not None
+        and not self.export_config.output_logits_on_prefill
+    ):
+      raise ValueError("Verifier requires logit output on prefill.")
     # Since the reauthored model doesn't include keyword arguments, pass
     # pixel_values only when it is not None. Otherwise, it may raise an error.
     if pixel_values is None:
-      output = self.model.forward(
+      output = self.model.forward(
+          tokens, input_pos, kv_cache, export_config=self.export_config
+      )
     else:
       output = self.model.forward(
-          tokens,
+          tokens,
+          input_pos,
+          kv_cache,
+          pixel_values=pixel_values,
+          export_config=self.export_config,
       )
     return output["logits"], output["kv_cache"]
 