ai-edge-torch-nightly 0.2.0.dev20240718__py3-none-any.whl → 0.2.0.dev20240720__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/conversion_utils.py +39 -18
- ai_edge_torch/convert/test/test_convert.py +106 -0
- ai_edge_torch/generative/examples/experimental/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +87 -0
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +195 -0
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +84 -0
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +184 -0
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +89 -0
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +185 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +6 -2
- ai_edge_torch/generative/examples/phi2/phi2.py +5 -2
- ai_edge_torch/generative/examples/t5/t5.py +5 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +42 -27
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +6 -2
- ai_edge_torch/generative/test/test_experimental_ekv.py +122 -0
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/RECORD +23 -12
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/top_level.txt +0 -0
|
@@ -20,7 +20,7 @@ import gc
|
|
|
20
20
|
import itertools
|
|
21
21
|
import logging
|
|
22
22
|
import tempfile
|
|
23
|
-
from typing import Any, Dict, Optional, Tuple, Union
|
|
23
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
24
24
|
|
|
25
25
|
import torch
|
|
26
26
|
import torch.utils._pytree as pytree
|
|
@@ -79,28 +79,49 @@ class Signature:
|
|
|
79
79
|
for i in range(args_spec.num_leaves):
|
|
80
80
|
names.append(f"args_{i}")
|
|
81
81
|
|
|
82
|
-
|
|
83
|
-
kwargs_spec.context
|
|
84
|
-
if kwargs_spec.type is not collections.defaultdict
|
|
85
|
-
# ignore mismatch of `default_factory` for defaultdict
|
|
86
|
-
else kwargs_spec.context[1]
|
|
82
|
+
kwargs_names = self._flat_kwarg_names(
|
|
83
|
+
kwargs_spec.children_specs, kwargs_spec.context
|
|
87
84
|
)
|
|
85
|
+
names.extend(kwargs_names)
|
|
86
|
+
return names
|
|
88
87
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
88
|
+
def _flat_kwarg_names(self, specs, context) -> List[str]:
|
|
89
|
+
flat_names = []
|
|
90
|
+
if context is None:
|
|
91
|
+
for i, spec in enumerate(specs):
|
|
92
|
+
if spec.children_specs:
|
|
93
|
+
flat_names.extend(
|
|
94
|
+
[
|
|
95
|
+
f"{i}_{name}"
|
|
96
|
+
for name in self._flat_kwarg_names(spec.children_specs, spec.context)
|
|
97
|
+
]
|
|
98
|
+
)
|
|
99
|
+
else:
|
|
100
|
+
flat_names.append(f"{i}")
|
|
101
|
+
else:
|
|
102
|
+
flat_ctx = self._flatten_list(context)
|
|
103
|
+
for prefix, spec in zip(flat_ctx, specs):
|
|
104
|
+
leaf_flat_names = self._flat_kwarg_names(spec.children_specs, spec.context)
|
|
105
|
+
if leaf_flat_names:
|
|
106
|
+
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
|
|
107
|
+
else:
|
|
108
|
+
flat_names.append(prefix)
|
|
109
|
+
|
|
110
|
+
return flat_names
|
|
111
|
+
|
|
112
|
+
def _flatten_list(self, l: List) -> List:
    """Recursively splice every nested `list` in `l` into one flat list.

    Only `list` instances are expanded. Any other item, tuples included, is
    kept as a single element. Relative order is preserved.
    """
    return [
        leaf
        for item in l
        for leaf in (self._flatten_list(item) if isinstance(item, list) else [item])
    ]
|
|
100
120
|
|
|
101
121
|
@property
def flat_args(self) -> tuple[Any]:
    """Positional sample args followed by the kwarg values, as one tuple.

    Kwarg values are appended in dict order and are not unpacked any further.
    """
    sample_args, sample_kwargs = self._normalized_sample_args_kwargs
    return tuple(itertools.chain(sample_args, sample_kwargs.values()))
|
|
104
125
|
|
|
105
126
|
|
|
106
127
|
def exported_program_to_stablehlo_bundle(
|
|
@@ -14,10 +14,14 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
from dataclasses import dataclass
|
|
17
18
|
import os
|
|
18
19
|
import tempfile
|
|
20
|
+
from typing import Tuple
|
|
19
21
|
import unittest
|
|
20
22
|
|
|
23
|
+
import numpy as np
|
|
24
|
+
import tensorflow as tf
|
|
21
25
|
import torch
|
|
22
26
|
import torchvision
|
|
23
27
|
|
|
@@ -26,6 +30,15 @@ from ai_edge_torch.convert import conversion_utils as cutils
|
|
|
26
30
|
from ai_edge_torch.testing import model_coverage
|
|
27
31
|
|
|
28
32
|
|
|
33
|
+
@dataclass
|
|
34
|
+
class TestContainer1:
|
|
35
|
+
data_1: torch.Tensor
|
|
36
|
+
data_2: Tuple[torch.Tensor, torch.Tensor]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
torch.export.register_dataclass(TestContainer1, serialized_type_name="TestContainer1")
|
|
40
|
+
|
|
41
|
+
|
|
29
42
|
class TestConvert(unittest.TestCase):
|
|
30
43
|
"""Tests conversion of various modules."""
|
|
31
44
|
|
|
@@ -306,6 +319,99 @@ class TestConvert(unittest.TestCase):
|
|
|
306
319
|
model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
|
|
307
320
|
)
|
|
308
321
|
|
|
322
|
+
def test_convert_model_with_args_nested_kwargs_1(self):
    """Nested kwargs: a dataclass kwarg whose `data_2` field is a tensor tuple.

    The nested leaves must surface as TFLite inputs named
    `z_data_1`, `z_data_2_0` and `z_data_2_1`.
    """

    class SampleModel(torch.nn.Module):

        def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
            return x + y + z.data_1 + z.data_2[0] + z.data_2[1]

    # Sample tensors are drawn in the same order as the inputs are declared.
    x = torch.randn(10, 10)
    y = torch.randn(10, 10)
    z = TestContainer1(
        data_1=torch.randn(10, 10),
        data_2=(torch.randn(10, 10), torch.randn(10, 10)),
    )
    args = (x,)
    kwargs = {"y": y, "z": z}
    flat_inputs = {
        "args_0": x.numpy(),
        "y": y.numpy(),
        "z_data_1": z.data_1.numpy(),
        "z_data_2_0": z.data_2[0].numpy(),
        "z_data_2_1": z.data_2[1].numpy(),
    }
    self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
|
|
348
|
+
|
|
349
|
+
def test_convert_model_with_args_nested_kwargs_2(self):
    """Nested kwargs: `data_2` is a list whose first element is a 1-tuple.

    The doubly nested leaf must surface as the TFLite input `z_data_2_0_0`.
    """

    class SampleModel(torch.nn.Module):

        def forward(self, x, y, z):
            return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]

    # Sample tensors are drawn in the same order as the inputs are declared.
    x = torch.randn(10, 10)
    y = torch.randn(10, 10)
    z = TestContainer1(
        data_1=torch.randn(10, 10),
        data_2=[(torch.randn(10, 10),), torch.randn(10, 10)],
    )
    args = (x,)
    kwargs = {"y": y, "z": z}
    flat_inputs = {
        "args_0": x.numpy(),
        "y": y.numpy(),
        "z_data_1": z.data_1.numpy(),
        "z_data_2_0_0": z.data_2[0][0].numpy(),
        "z_data_2_1": z.data_2[1].numpy(),
    }
    self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
|
|
375
|
+
|
|
376
|
+
def test_convert_model_with_args_nested_kwargs_3(self):
    """Nested kwargs: `data_2` is a tuple whose first element is a dict.

    The dict entry must surface as the TFLite input `z_data_2_0_foo`.
    """

    class SampleModel(torch.nn.Module):

        def forward(self, x, y, z):
            return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]

    # Sample tensors are drawn in the same order as the inputs are declared.
    x = torch.randn(10, 10)
    y = torch.randn(10, 10)
    z = TestContainer1(
        data_1=torch.randn(10, 10),
        data_2=(dict(foo=torch.randn(10, 10)), torch.randn(10, 10)),
    )
    args = (x,)
    kwargs = {"y": y, "z": z}
    flat_inputs = {
        "args_0": x.numpy(),
        "y": y.numpy(),
        "z_data_1": z.data_1.numpy(),
        "z_data_2_0_foo": z.data_2[0]["foo"].numpy(),
        "z_data_2_1": z.data_2[1].numpy(),
    }
    self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
|
|
402
|
+
|
|
403
|
+
def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
    """Convert `model`, check its TFLite input names, and compare outputs.

    Args:
      model: The torch module under test; it is switched to eval mode here.
      args: Positional sample inputs, used for conversion and the torch run.
      kwargs: Keyword sample inputs, possibly nested containers.
      flat_inputs: Expected "serving_default" inputs, keyed by flat name.
    """
    model.eval()
    edge_model = ai_edge_torch.convert(model, args, kwargs)

    # The converted signature must expose exactly the expected input names.
    tfl_interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
    signature_runner = tfl_interpreter.get_signature_runner("serving_default")
    self.assertEqual(signature_runner.get_input_details().keys(), flat_inputs.keys())

    # The torch reference is computed first, then the TFLite result.
    np.testing.assert_almost_equal(model(*args, **kwargs), edge_model(**flat_inputs))
|
|
414
|
+
|
|
309
415
|
|
|
310
416
|
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
#
|
|
16
|
+
# Note: This is an experimental version of Gemma with external KV cache.
|
|
17
|
+
# Please use with caution.
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
import ai_edge_torch
|
|
26
|
+
from ai_edge_torch.generative.examples.experimental.gemma import gemma
|
|
27
|
+
from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
|
28
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def convert_gemma_to_tflite(
    checkpoint_path: str,
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
    output_dir: str = '/tmp',
) -> str:
    """An example method for converting a Gemma 2B model to multi-signature
    tflite model.

    The exported model has a 'prefill' and a 'decode' signature. Both take
    `tokens`, `input_pos` and an external `kv_cache`.

    Args:
      checkpoint_path (str): The filepath to the model checkpoint, or directory
        holding the checkpoint.
      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
        Defaults to 512.
      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
        including both prefill and decode. Defaults to 1024.
      quantize (bool, optional): Whether the model should be quantized.
        Defaults to True.
      output_dir (str, optional): Directory that receives the exported
        .tflite file. Defaults to '/tmp'.

    Returns:
      str: The path of the exported .tflite file.
    """
    pytorch_model = gemma.build_2b_model(
        checkpoint_path, kv_cache_max_len=kv_cache_max_len
    )
    # Tensors used to trace the model graph during conversion.
    prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
    kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
            'prefill',
            pytorch_model,
            sample_kwargs={
                'tokens': prefill_tokens,
                'input_pos': prefill_input_pos,
                'kv_cache': kv,
            },
        )
        .signature(
            'decode',
            pytorch_model,
            sample_kwargs={
                'tokens': decode_token,
                'input_pos': decode_input_pos,
                'kv_cache': kv,
            },
        )
        .convert(quant_config=quant_config)
    )
    # With the default output_dir this is the same path as before:
    # /tmp/gemma_seq{prefill}_ekv{kv}.tflite
    output_path = os.path.join(
        output_dir, f'gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )
    edge_model.export(output_path)
    return output_path
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
if __name__ == '__main__':
    # Expects the Gemma 2B checkpoint under ~/Downloads/llm_data/gemma-2b.
    checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
    convert_gemma_to_tflite(checkpoint_path)
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# Example of building a Gemma model.
|
|
16
|
+
#
|
|
17
|
+
# Note: This is an experimental version of Gemma with external KV cache.
|
|
18
|
+
# Please use with caution.
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Tuple
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import torch
|
|
26
|
+
import torch.nn as nn
|
|
27
|
+
|
|
28
|
+
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
|
29
|
+
import ai_edge_torch.generative.layers.builder as builder
|
|
30
|
+
from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
|
31
|
+
from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
|
|
32
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
|
33
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
|
34
|
+
|
|
35
|
+
# Checkpoint tensor names for each model component; "{}" is the layer index.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    # Embedding and output.
    embedding="model.embed_tokens",
    final_norm="model.norm",
    # No separate head weights: the model ties lm_head to the embedding.
    lm_head=None,
    # Per-layer attention.
    pre_attn_norm="model.layers.{}.input_layernorm",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    # Per-layer gated feed-forward.
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Gemma(nn.Module):
    """Gemma decoder whose KV cache is passed in and returned by `forward`.

    Experimental: the cache is external to the module (an `EKVCache`) rather
    than held as module state.
    """

    def __init__(self, config: cfg.ModelConfig):
        super().__init__()

        self.config = config
        # Construct model layers.
        self.tok_embedding = nn.Embedding(
            config.vocab_size, config.embedding_dim, padding_idx=0
        )
        self.lm_head = nn.Linear(
            config.embedding_dim,
            config.vocab_size,
            bias=config.lm_head_use_bias,
        )
        # Gemma re-uses the embedding as the head projection layer.
        self.lm_head.weight.data = self.tok_embedding.weight.data
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.num_layers)
        )
        self.final_norm = builder.build_norm(
            config.embedding_dim,
            config.final_norm_config,
        )
        # RoPE and causal-mask tables are sized by config.kv_cache_max.
        # forward() selects the entries for the current `input_pos`.
        self.rope_cache = attn_utils.build_rope_cache(
            size=config.kv_cache_max,
            dim=int(config.attn_config.rotary_percentage * config.head_dim),
            base=10_000,
            condense_ratio=1,
            dtype=torch.float32,
            device=torch.device("cpu"),
        )
        self.mask_cache = attn_utils.build_causal_mask_cache(
            size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
        )

    @torch.inference_mode
    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: torch.Tensor,
        kv_cache: kv_utils.EKVCache,
    ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
        """Run the decoder over `tokens` located at positions `input_pos`.

        Args:
          tokens: Token ids; the code unpacks this as (batch, seq_len).
          input_pos: Positions of `tokens`; selects the RoPE and mask entries.
          kv_cache: External cache whose `caches[i]` feeds transformer block i.
            None runs every block with a None cache entry.

        Returns:
          The `lm_head` logits, (b, t, vocab_size), and a new EKVCache built from
          the non-None entries returned by the transformer blocks.
        """
        _, seq_len = tokens.size()
        assert (
            self.config.max_seq_len >= seq_len
        ), f"Cannot forward sequence of length {seq_len}, max seq length is only {self.config.max_seq_len}"

        cos, sin = self.rope_cache
        cos = cos.index_select(0, input_pos)
        sin = sin.index_select(0, input_pos)
        mask = self.mask_cache.index_select(2, input_pos)
        mask = mask[:, :, :, : self.config.kv_cache_max]

        # token embeddings of shape (b, t, n_embd)
        x = self.tok_embedding(tokens)
        x = x * (self.config.embedding_dim**0.5)

        updated_kv_entries = []
        for i, block in enumerate(self.transformer_blocks):
            # Explicit None checks avoid relying on the truth value of cache
            # objects.
            kv_entry = kv_cache.caches[i] if kv_cache is not None else None
            x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
            if kv_entry is not None:
                updated_kv_entries.append(kv_entry)
        updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entries))

        x = self.final_norm(x)
        res = self.lm_head(x)  # (b, t, vocab_size)
        return res, updated_kv_cache
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
    """Return the Gemma 2B model configuration.

    Args:
      kv_cache_max_len: Size of the external KV cache buffer. Defaults to 1024.
    """
    # One zero-centered RMSNorm config, shared by all three norm slots.
    shared_norm = cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
    )
    return cfg.ModelConfig(
        vocab_size=256000,
        num_layers=18,
        max_seq_len=8192,
        embedding_dim=2048,
        kv_cache_max_len=kv_cache_max_len,
        attn_config=cfg.AttentionConfig(
            num_heads=8, num_query_groups=1, rotary_percentage=1.0
        ),
        ff_config=cfg.FeedForwardConfig(
            type=cfg.FeedForwardType.GATED,
            activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
            intermediate_size=16384,
        ),
        pre_attention_norm_config=shared_norm,
        pre_ff_norm_config=shared_norm,
        final_norm_config=shared_norm,
        parallel_residual=False,
        lm_head_use_bias=False,
        enable_hlfb=True,
    )
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def get_fake_model_config_2b_for_test(**kwargs) -> cfg.ModelConfig:
    """Return the 2B config reduced to two transformer layers for tests.

    Keyword arguments are forwarded unchanged to `get_model_config_2b`.
    """
    small_config = get_model_config_2b(**kwargs)
    small_config.num_layers = 2
    return small_config
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
    """Construct a Gemma 2B model and optionally load checkpoint weights.

    Args:
      checkpoint_path: Checkpoint file or directory. If None, no weights are
        loaded.
      test_model: If True, use the reduced two-layer test configuration.
      **kwargs: Forwarded to the config builder (e.g. `kv_cache_max_len`).

    Returns:
      The model, switched to eval mode.
    """
    if test_model:
        model_config = get_fake_model_config_2b_for_test(**kwargs)
    else:
        model_config = get_model_config_2b(**kwargs)
    model = Gemma(model_config)
    if checkpoint_path is not None:
        weight_loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
        # lm_head shares the embedding weight and has no checkpoint entry of
        # its own, so loading must not be strict.
        weight_loader.load(model, strict=False)
    model.eval()
    return model
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def define_and_run_2b(
    checkpoint_path, test_model=False, kv_cache_max_len: int = 1024
) -> None:
    """Build a Gemma 2B model and print the result of one forward pass.

    The prompt ids [1, 2, 3, 4] are zero-padded to `kv_cache_max_len` and run
    as a single prefill covering every cache position.

    Args:
      checkpoint_path: Checkpoint file or directory. If None, the freshly
        constructed weights are used.
      test_model: If True, build the reduced two-layer test configuration.
      kv_cache_max_len: KV cache size, also used as the prefill length here.
        Defaults to 1024.
    """
    model = build_2b_model(
        checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
    )
    prompt = torch.tensor([1, 2, 3, 4], dtype=torch.long)
    tokens = torch.zeros((1, kv_cache_max_len), dtype=torch.long, device="cpu")
    tokens[0, : prompt.numel()] = prompt
    input_pos = torch.arange(0, kv_cache_max_len)
    kv = kv_utils.EKVCache.from_model_config(model.config)
    print("running an inference")
    # Call the module rather than .forward() so nn.Module.__call__ runs.
    print(model(tokens, input_pos, kv))
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
if __name__ == "__main__":
    # Expects the Gemma 2B checkpoint under ~/Downloads/gemma-2b.
    checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
    define_and_run_2b(checkpoint_path)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
#
|
|
16
|
+
# Note: This is an experimental version of phi2 with external KV cache.
|
|
17
|
+
# Please use with caution.
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
import ai_edge_torch
|
|
25
|
+
from ai_edge_torch.generative.examples.experimental.phi import phi2
|
|
26
|
+
from ai_edge_torch.generative.layers.experimental import ekv_cache
|
|
27
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def convert_phi2_to_tflite(
    checkpoint_path: str,
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
    output_dir: str = '/tmp',
) -> str:
    """An example method for converting a Phi-2 model to multi-signature
    tflite model.

    The exported model has a 'prefill' and a 'decode' signature. Both take
    `tokens`, `input_pos` and an external `kv_cache`.

    Args:
      checkpoint_path (str): The filepath to the model checkpoint, or
        directory holding the checkpoint.
      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
        Defaults to 512.
      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
        including both prefill and decode. Defaults to 1024.
      quantize (bool, optional): Whether the model should be quantized.
        Defaults to True.
      output_dir (str, optional): Directory that receives the exported
        .tflite file. Defaults to '/tmp'.

    Returns:
      str: The path of the exported .tflite file.
    """
    pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    # Tensors used to trace the model graph during conversion.
    prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
    kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
            'prefill',
            pytorch_model,
            sample_kwargs={
                'tokens': prefill_tokens,
                'input_pos': prefill_input_pos,
                'kv_cache': kv,
            },
        )
        .signature(
            'decode',
            pytorch_model,
            sample_kwargs={
                'tokens': decode_token,
                'input_pos': decode_input_pos,
                'kv_cache': kv,
            },
        )
        .convert(quant_config=quant_config)
    )
    # With the default output_dir this is the same path as before:
    # /tmp/phi2_seq{prefill}_ekv{kv}.tflite
    output_path = os.path.join(
        output_dir, f'phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )
    edge_model.export(output_path)
    return output_path
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
if __name__ == '__main__':
    # Expects the Phi-2 checkpoint under ~/Downloads/llm_data/phi2.
    checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
    convert_phi2_to_tflite(checkpoint_path)
|