ai-edge-torch-nightly 0.2.0.dev20240718__py3-none-any.whl → 0.2.0.dev20240720__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Files changed (23):
  1. ai_edge_torch/convert/conversion_utils.py +39 -18
  2. ai_edge_torch/convert/test/test_convert.py +106 -0
  3. ai_edge_torch/generative/examples/experimental/__init__.py +14 -0
  4. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +14 -0
  5. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +87 -0
  6. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +195 -0
  7. ai_edge_torch/generative/examples/experimental/phi/__init__.py +14 -0
  8. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +84 -0
  9. ai_edge_torch/generative/examples/experimental/phi/phi2.py +184 -0
  10. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +14 -0
  11. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +89 -0
  12. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +185 -0
  13. ai_edge_torch/generative/examples/gemma/gemma.py +6 -2
  14. ai_edge_torch/generative/examples/phi2/phi2.py +5 -2
  15. ai_edge_torch/generative/examples/t5/t5.py +5 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +42 -27
  17. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +6 -2
  18. ai_edge_torch/generative/test/test_experimental_ekv.py +122 -0
  19. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/RECORD +23 -12
  21. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/conversion_utils.py

@@ -20,7 +20,7 @@ import gc
 import itertools
 import logging
 import tempfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -79,28 +79,49 @@ class Signature:
     for i in range(args_spec.num_leaves):
       names.append(f"args_{i}")
 
-    dict_context = (
-        kwargs_spec.context
-        if kwargs_spec.type is not collections.defaultdict
-        # ignore mismatch of `default_factory` for defaultdict
-        else kwargs_spec.context[1]
+    kwargs_names = self._flat_kwarg_names(
+        kwargs_spec.children_specs, kwargs_spec.context
     )
+    names.extend(kwargs_names)
+    return names
 
-    for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
-      if value_spec.num_leaves == 1:
-        names.append(name)
+  def _flat_kwarg_names(self, specs, context) -> List[str]:
+    flat_names = []
+    if context is None:
+      for i, spec in enumerate(specs):
+        if spec.children_specs:
+          flat_names.extend(
+              [
+                  f"{i}_{name}"
+                  for name in self._flat_kwarg_names(spec.children_specs, spec.context)
+              ]
+          )
+        else:
+          flat_names.append(f"{i}")
+    else:
+      flat_ctx = self._flatten_list(context)
+      for prefix, spec in zip(flat_ctx, specs):
+        leaf_flat_names = self._flat_kwarg_names(spec.children_specs, spec.context)
+        if leaf_flat_names:
+          flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
+        else:
+          flat_names.append(prefix)
+
+    return flat_names
+
+  def _flatten_list(self, l: List) -> List:
+    flattened = []
+    for item in l:
+      if isinstance(item, list):
+        flattened.extend(self._flatten_list(item))
       else:
-        # value_spec.num_leaves may be greater than 1 when the value is a (nested)
-        # tuple of tensors. We haven't decided how we should support flattenable
-        # tensor containers as inputs.
-        # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
-        for i in range(value_spec.num_leaves):
-          names.append(f"{name}_{i}")
-    return names
+        flattened.append(item)
+    return flattened
 
 
   @property
-  def flat_args(self) -> tuple[torch.Tensor]:
-    return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
+  def flat_args(self) -> tuple[Any]:
+    args, kwargs = self._normalized_sample_args_kwargs
+    return tuple([*args, *kwargs.values()])
 
 
 def exported_program_to_stablehlo_bundle(
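
The practical effect of the new `_flat_kwarg_names` helper is a predictable naming scheme for nested kwargs in the converted model: dict keys and dataclass field names are kept, tuple/list positions become indices, and path segments are joined with underscores. A minimal sketch of what a caller should expect, assuming the `ai_edge_torch.convert(model, sample_args, sample_kwargs)` calling convention exercised in the test changes below; the `AddAll` module and its tensors are hypothetical:

import torch

import ai_edge_torch

class AddAll(torch.nn.Module):
  # Hypothetical module: sums a positional tensor with nested kwarg tensors.
  def forward(self, x, opts):
    return x + opts["scale"] + opts["bias"][0] + opts["bias"][1]

args = (torch.randn(4),)
kwargs = {"opts": {"scale": torch.randn(4), "bias": (torch.randn(4), torch.randn(4))}}
edge_model = ai_edge_torch.convert(AddAll().eval(), args, kwargs)
# The flattened signature should expose one input per leaf tensor:
#   "args_0", "opts_scale", "opts_bias_0", "opts_bias_1"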
ai_edge_torch/convert/test/test_convert.py

@@ -14,10 +14,14 @@
 # ==============================================================================
 
 
+from dataclasses import dataclass
 import os
 import tempfile
+from typing import Tuple
 import unittest
 
+import numpy as np
+import tensorflow as tf
 import torch
 import torchvision
 
@@ -26,6 +30,15 @@ from ai_edge_torch.convert import conversion_utils as cutils
 from ai_edge_torch.testing import model_coverage
 
 
+@dataclass
+class TestContainer1:
+  data_1: torch.Tensor
+  data_2: Tuple[torch.Tensor, torch.Tensor]
+
+
+torch.export.register_dataclass(TestContainer1, serialized_type_name="TestContainer1")
+
+
 class TestConvert(unittest.TestCase):
   """Tests conversion of various modules."""
 
@@ -306,6 +319,99 @@ class TestConvert(unittest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
     )
 
+  def test_convert_model_with_args_nested_kwargs_1(self):
+    """
+    Test converting a simple model with both sample_args and nested sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
+        return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
+
+    args = (torch.randn(10, 10),)
+    kwargs = dict(
+        y=torch.randn(10, 10),
+        z=TestContainer1(
+            data_1=torch.randn(10, 10),
+            data_2=(torch.randn(10, 10), torch.randn(10, 10)),
+        ),
+    )
+    flat_inputs = {
+        "args_0": args[0].numpy(),
+        "y": kwargs["y"].numpy(),
+        "z_data_1": kwargs["z"].data_1.numpy(),
+        "z_data_2_0": kwargs["z"].data_2[0].numpy(),
+        "z_data_2_1": kwargs["z"].data_2[1].numpy(),
+    }
+    self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
+
+  def test_convert_model_with_args_nested_kwargs_2(self):
+    """
+    Test converting a simple model with both sample_args and nested sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y, z):
+        return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
+
+    args = (torch.randn(10, 10),)
+    kwargs = dict(
+        y=torch.randn(10, 10),
+        z=TestContainer1(
+            data_1=torch.randn(10, 10),
+            data_2=[(torch.randn(10, 10),), torch.randn(10, 10)],
+        ),
+    )
+    flat_inputs = {
+        "args_0": args[0].numpy(),
+        "y": kwargs["y"].numpy(),
+        "z_data_1": kwargs["z"].data_1.numpy(),
+        "z_data_2_0_0": kwargs["z"].data_2[0][0].numpy(),
+        "z_data_2_1": kwargs["z"].data_2[1].numpy(),
+    }
+    self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
+
+  def test_convert_model_with_args_nested_kwargs_3(self):
+    """
+    Test converting a simple model with both sample_args and nested sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y, z):
+        return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
+
+    args = (torch.randn(10, 10),)
+    kwargs = dict(
+        y=torch.randn(10, 10),
+        z=TestContainer1(
+            data_1=torch.randn(10, 10),
+            data_2=(dict(foo=torch.randn(10, 10)), torch.randn(10, 10)),
+        ),
+    )
+    flat_inputs = {
+        "args_0": args[0].numpy(),
+        "y": kwargs["y"].numpy(),
+        "z_data_1": kwargs["z"].data_1.numpy(),
+        "z_data_2_0_foo": kwargs["z"].data_2[0]["foo"].numpy(),
+        "z_data_2_1": kwargs["z"].data_2[1].numpy(),
+    }
+    self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
+
+  def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
+    model.eval()
+    edge_model = ai_edge_torch.convert(model, args, kwargs)
+    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    runner = interpreter.get_signature_runner("serving_default")
+    input_details = runner.get_input_details()
+    self.assertEqual(input_details.keys(), flat_inputs.keys())
+
+    reference_output = model(*args, **kwargs)
+    tflite_output = edge_model(**flat_inputs)
+    np.testing.assert_almost_equal(reference_output, tflite_output)
+
 
 if __name__ == "__main__":
   unittest.main()
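
The helper above drives the comparison through `edge_model(**flat_inputs)`; the same flattened names are what the underlying TFLite signature runner reports. A self-contained sketch of that path, using a hypothetical single-input `torch.nn.Linear` model (`_tflite_model` is a private attribute, touched here only the way the tests do):

import numpy as np
import tensorflow as tf
import torch

import ai_edge_torch

# Hypothetical model with one positional input; it flattens to the name "args_0".
model = torch.nn.Linear(4, 2).eval()
args = (torch.randn(1, 4),)
edge_model = ai_edge_torch.convert(model, args)
interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
runner = interpreter.get_signature_runner("serving_default")
print(runner.get_input_details().keys())  # expected: dict_keys(['args_0'])
np.testing.assert_almost_equal(
    edge_model(args_0=args[0].numpy()),
    model(args[0]).detach().numpy(),
    decimal=5,
)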
ai_edge_torch/generative/examples/experimental/__init__.py

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/experimental/gemma/__init__.py

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py

@@ -0,0 +1,87 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Note: This is an experimental version of Gemma with external KV cache.
+# Please use with caution.
+
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.experimental.gemma import gemma
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_gemma_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Gemma 2B model to a multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of the prefill input
+      tensor. Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of the KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = gemma.build_2b_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(checkpoint_path)
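
A converted bundle like this can be sanity-checked by listing its signatures with the TFLite interpreter. A quick sketch, assuming the default arguments above and therefore the default export path:

import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path="/tmp/gemma_seq512_ekv1024.tflite"
)
# Expect the two signatures registered above: 'prefill' and 'decode'.
print(interpreter.get_signature_list())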
ai_edge_torch/generative/examples/experimental/gemma/gemma.py

@@ -0,0 +1,195 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a Gemma model.
+#
+# Note: This is an experimental version of Gemma with external KV cache.
+# Please use with caution.
+
+import os
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    pre_ff_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head=None,
+)
+
+
+class Gemma(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    # Gemma re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.EKVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+    B, T = tokens.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # Token embeddings of shape (b, t, n_embd).
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entries = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entries))
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res, updated_kv_cache
+
+
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=256000,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_2b_for_test(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_2b(**kwargs)
+  config.num_layers = 2
+  return config
+
+
+def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+  config = (
+      get_fake_model_config_2b_for_test(**kwargs)
+      if test_model
+      else get_model_config_2b(**kwargs)
+  )
+  model = Gemma(config)
+  if checkpoint_path is not None:
+    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+    # Since the embedding and lm-head share the same weight, strict loading
+    # must be disabled.
+    loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run_2b(checkpoint_path, test_model=False) -> None:
+  kv_cache_max_len = 1024
+  model = build_2b_model(
+      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
+  )
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.EKVCache.from_model_config(model.config)
+  print("running an inference")
+  print(model.forward(tokens, input_pos, kv))
+
+
+if __name__ == "__main__":
+  checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
+  define_and_run_2b(checkpoint_path)
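
Since the KV cache here is a pure input/output value rather than module state, the caller threads it through successive calls. A minimal greedy decode sketch on the PyTorch side, assuming the two-layer fake test config so no checkpoint is needed:

import torch

from ai_edge_torch.generative.examples.experimental.gemma import gemma
from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils

# Randomly initialized two-layer model; build_2b_model skips loading when
# checkpoint_path is None.
model = gemma.build_2b_model(None, test_model=True, kv_cache_max_len=1024)
kv = kv_utils.EKVCache.from_model_config(model.config)
token = torch.tensor([[1]], dtype=torch.long)
for pos in range(8):
  input_pos = torch.tensor([pos], dtype=torch.int64)
  # The updated cache is returned rather than mutated in place.
  logits, kv = model.forward(token, input_pos, kv)
  token = logits[:, -1].argmax(dim=-1, keepdim=True)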
ai_edge_torch/generative/examples/experimental/phi/__init__.py

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py

@@ -0,0 +1,84 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Note: This is an experimental version of phi2 with external KV cache.
+# Please use with caution.
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.experimental.phi import phi2
+from ai_edge_torch.generative.layers.experimental import ekv_cache
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_phi2_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Phi-2 model to a multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or
+      directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of the prefill input
+      tensor. Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of the KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(checkpoint_path)