ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (68) hide show
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ import logging
17
17
  import os
18
18
  from typing import Any, Optional
19
19
 
20
+ from ai_edge_torch import fx_pass_base
20
21
  from ai_edge_torch import lowertools
21
22
  from ai_edge_torch import model
22
23
  from ai_edge_torch._convert import fx_passes
@@ -34,7 +35,7 @@ def _run_convert_passes(
34
35
  exported_program = generative_fx_passes.run_generative_passes(
35
36
  exported_program
36
37
  )
37
- return fx_passes.run_passes(
38
+ return fx_pass_base.run_passes(
38
39
  exported_program,
39
40
  [
40
41
  fx_passes.BuildInterpolateCompositePass(),
@@ -15,44 +15,8 @@
15
15
 
16
16
  from typing import Sequence, Union
17
17
 
18
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
19
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
20
- from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
21
- from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
22
- from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
23
- from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA
24
- from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
25
- from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
26
- from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
27
- from torch.export import ExportedProgram
28
- from torch.fx.passes.infra.pass_manager import pass_result_wrapper
29
- import torch.utils._pytree as pytree
30
-
31
-
32
- # TODO(cnchan): make a PassManager class.
33
- def run_passes(
34
- exported_program: ExportedProgram,
35
- passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
36
- ) -> ExportedProgram:
37
- passes, _ = pytree.tree_flatten(passes)
38
- for pass_ in passes:
39
- if not isinstance(pass_, ExportedProgramPassBase):
40
- pass_ = pass_result_wrapper(pass_)
41
- if isinstance(pass_, ExportedProgramPassBase):
42
- exported_program = pass_(exported_program).exported_program
43
- else:
44
- gm = exported_program.graph_module
45
- gm, modified = pass_(gm)
46
- if modified and gm is not exported_program.graph_module:
47
- exported_program = ExportedProgram(
48
- root=gm,
49
- graph=gm.graph,
50
- graph_signature=exported_program.graph_signature,
51
- state_dict=exported_program.state_dict,
52
- range_constraints=exported_program.range_constraints,
53
- module_call_graph=exported_program.module_call_graph,
54
- example_inputs=exported_program.example_inputs,
55
- verifier=exported_program.verifier,
56
- constants=exported_program.constants,
57
- )
58
- return exported_program
18
+ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
19
+ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
20
+ from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
21
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
22
+ from ai_edge_torch.fx_pass_base import CanonicalizePass
@@ -13,11 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from functools import reduce
17
16
  from typing import Any, Callable
17
+ from ai_edge_torch import fx_pass_base
18
18
  from ai_edge_torch import lowertools
19
19
  import torch
20
- from torch.fx.passes.infra import pass_base
21
20
  import torch.utils._pytree as pytree
22
21
 
23
22
  _composite_builders: dict[
@@ -277,7 +276,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
277
276
  node.target = embedding
278
277
 
279
278
 
280
- class BuildAtenCompositePass(pass_base.PassBase):
279
+ class BuildAtenCompositePass(fx_pass_base.PassBase):
281
280
 
282
281
  def call(self, graph_module: torch.fx.GraphModule):
283
282
  for node in graph_module.graph.nodes:
@@ -286,4 +285,4 @@ class BuildAtenCompositePass(pass_base.PassBase):
286
285
 
287
286
  graph_module.graph.lint()
288
287
  graph_module.recompile()
289
- return pass_base.PassResult(graph_module, True)
288
+ return fx_pass_base.PassResult(graph_module, True)
@@ -16,8 +16,7 @@
16
16
 
17
17
  import functools
18
18
 
19
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
20
- from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
19
+ from ai_edge_torch import fx_pass_base
21
20
  from ai_edge_torch.hlfb import mark_pattern
22
21
  from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
23
22
  import torch
@@ -103,7 +102,7 @@ def _get_interpolate_nearest2d_pattern():
103
102
  return pattern
104
103
 
105
104
 
106
- class BuildInterpolateCompositePass(ExportedProgramPassBase):
105
+ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
107
106
 
108
107
  def __init__(self):
109
108
  super().__init__()
@@ -124,4 +123,4 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
124
123
 
125
124
  graph_module.graph.lint()
126
125
  graph_module.recompile()
127
- return ExportedProgramPassResult(exported_program, True)
126
+ return fx_pass_base.ExportedProgramPassResult(exported_program, True)
@@ -13,10 +13,9 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from ai_edge_torch import fx_pass_base
16
17
  from ai_edge_torch import lowertools
17
18
  import torch
18
- from torch.fx.passes.infra.pass_base import PassBase
19
- from torch.fx.passes.infra.pass_base import PassResult
20
19
  import torch.utils._pytree as pytree
21
20
 
22
21
 
@@ -62,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
62
61
  node.target = debuginfo_writer
63
62
 
64
63
 
65
- class InjectMlirDebuginfoPass(PassBase):
64
+ class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
66
65
 
67
66
  def call(self, graph_module: torch.fx.GraphModule):
68
67
  for node in graph_module.graph.nodes:
@@ -70,4 +69,4 @@ class InjectMlirDebuginfoPass(PassBase):
70
69
 
71
70
  graph_module.graph.lint()
72
71
  graph_module.recompile()
73
- return PassResult(graph_module, True)
72
+ return fx_pass_base.PassResult(graph_module, True)
@@ -18,8 +18,7 @@ import operator
18
18
  import os
19
19
  from typing import Union
20
20
 
21
- from ai_edge_torch._convert.fx_passes import ExportedProgramPassBase
22
- from ai_edge_torch._convert.fx_passes import ExportedProgramPassResult
21
+ from ai_edge_torch import fx_pass_base
23
22
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
24
23
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
25
24
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
@@ -31,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
31
30
  TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
32
31
 
33
32
 
34
- class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
33
+ class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
35
34
 
36
35
  def get_source_meta(self, node: torch.fx.Node):
37
36
  keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
@@ -94,7 +93,7 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
94
93
 
95
94
  q_args = input_q.args[1:]
96
95
  q_kwargs = input_q.kwargs
97
- q_op, dq_op = self.get_paired_q_dq_ops(input_q.target)
96
+ q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
98
97
  with graph.inserting_before(target):
99
98
  # Q and DQ inserted here may required updating the `axis` arg when they
100
99
  # are per_channel ops. However, instead of updating here, the nodes would
@@ -301,4 +300,4 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
301
300
  # Mark const node again for debugging
302
301
  self.mark_const_nodes(exported_program)
303
302
 
304
- return ExportedProgramPassResult(exported_program, True)
303
+ return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/config.py CHANGED
@@ -21,4 +21,7 @@ import os
21
21
 
22
22
  @dataclasses.dataclass
23
23
  class Config:
24
- use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "True") == "True"
24
+ use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
25
+ "1",
26
+ "true",
27
+ )
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import abc
17
+ import collections
18
+ from typing import Sequence, Union
19
+
20
+ import torch
21
+ from torch.fx.passes.infra.pass_base import PassBase
22
+ from torch.fx.passes.infra.pass_base import PassResult
23
+ from torch.fx.passes.infra.pass_manager import pass_result_wrapper
24
+ import torch.utils._pytree as pytree
25
+
26
+ FxPassBase = PassBase
27
+ FxPassResult = PassResult
28
+ ExportedProgramPassResult = collections.namedtuple(
29
+ "ExportedProgramPassResult", ["exported_program", "modified"]
30
+ )
31
+
32
+
33
+ class ExportedProgramPassBase(abc.ABC):
34
+
35
+ def __call__(
36
+ self, exported_program: torch.export.ExportedProgram
37
+ ) -> ExportedProgramPassResult:
38
+ self.requires(exported_program)
39
+ res = self.call(exported_program)
40
+ self.ensures(exported_program)
41
+ return res
42
+
43
+ @abc.abstractmethod
44
+ def call(
45
+ self, exported_program: torch.export.ExportedProgram
46
+ ) -> ExportedProgramPassResult:
47
+ pass
48
+
49
+ def requires(self, exported_program: torch.export.ExportedProgram) -> None:
50
+ pass
51
+
52
+ def ensures(self, exported_program: torch.export.ExportedProgram) -> None:
53
+ pass
54
+
55
+
56
+ # TODO(cnchan): make a PassManager class.
57
+ def run_passes(
58
+ exported_program: torch.export.ExportedProgram,
59
+ passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
60
+ ) -> torch.export.ExportedProgram:
61
+ passes, _ = pytree.tree_flatten(passes)
62
+ for pass_ in passes:
63
+ if not isinstance(pass_, ExportedProgramPassBase):
64
+ pass_ = pass_result_wrapper(pass_)
65
+ if isinstance(pass_, ExportedProgramPassBase):
66
+ exported_program = pass_(exported_program).exported_program
67
+ else:
68
+ gm = exported_program.graph_module
69
+ gm, modified = pass_(gm)
70
+ if modified and gm is not exported_program.graph_module:
71
+ exported_program = torch.export.ExportedProgram(
72
+ root=gm,
73
+ graph=gm.graph,
74
+ graph_signature=exported_program.graph_signature,
75
+ state_dict=exported_program.state_dict,
76
+ range_constraints=exported_program.range_constraints,
77
+ module_call_graph=exported_program.module_call_graph,
78
+ example_inputs=exported_program.example_inputs,
79
+ verifier=exported_program.verifier,
80
+ constants=exported_program.constants,
81
+ )
82
+ return exported_program
83
+
84
+
85
+ class CanonicalizePass(ExportedProgramPassBase):
86
+
87
+ # A dummy decomp table for running ExportedProgram.run_decompositions without
88
+ # any op decompositions but just aot_export_module. Due to the check in
89
+ # run_decompositions, if None or an empty dict is passed as decomp_table,
90
+ # it will run the default aten-coreaten decompositions. Therefore a non-empty
91
+ # dummy decomp table is needed.
92
+ # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
93
+ _DUMMY_DECOMP_TABLE = {
94
+ torch._ops.OperatorBase(): lambda: None,
95
+ }
96
+
97
+ def call(self, exported_program: torch.export.ExportedProgram):
98
+ exported_program = exported_program.run_decompositions(
99
+ self._DUMMY_DECOMP_TABLE
100
+ )
101
+ return ExportedProgramPassResult(exported_program, True)
@@ -13,55 +13,74 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """Example of converting a Gemma2 model to multi-signature tflite model."""
17
+
16
18
  import os
17
- from pathlib import Path
19
+ import pathlib
18
20
 
19
21
  import ai_edge_torch
20
22
  from ai_edge_torch.generative.examples.gemma import gemma2
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
24
  from ai_edge_torch.generative.quantize import quant_recipes
22
25
  import torch
23
26
 
24
27
 
25
- def convert_gemma_to_tflite(
28
+ def convert_gemma2_to_tflite(
26
29
  checkpoint_path: str,
27
30
  prefill_seq_len: int = 512,
28
31
  kv_cache_max_len: int = 1024,
29
32
  quantize: bool = True,
30
33
  ):
31
- """Converting a Gemma 2 2B model to multi-signature
32
- tflite model.
34
+ """Converts a Gemma2 2B model to multi-signature tflite model.
33
35
 
34
36
  Args:
35
- checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
37
+ checkpoint_path (str): The filepath to the model checkpoint, or directory
38
+ holding the checkpoint.
36
39
  prefill_seq_len (int, optional): The maximum size of prefill input tensor.
37
40
  Defaults to 512.
38
41
  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
39
42
  including both prefill and decode. Defaults to 1024.
40
- quantize (bool, optional): Whether the model should be quanized.
41
- Defaults to True.
43
+ quantize (bool, optional): Whether the model should be quanized. Defaults
44
+ to True.
42
45
  """
43
46
  pytorch_model = gemma2.build_2b_model(
44
47
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
45
48
  )
46
49
  # Tensors used to trace the model graph during conversion.
47
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
48
- prefill_input_pos = torch.arange(0, prefill_seq_len)
49
- decode_token = torch.tensor([[0]], dtype=torch.long)
50
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
50
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
51
+ prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
52
+ decode_token = torch.tensor([[0]], dtype=torch.int)
53
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
54
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
51
55
 
52
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
53
57
  edge_model = (
54
58
  ai_edge_torch.signature(
55
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
59
+ 'prefill',
60
+ pytorch_model,
61
+ sample_kwargs={
62
+ 'tokens': prefill_tokens,
63
+ 'input_pos': prefill_input_pos,
64
+ 'kv_cache': kv,
65
+ },
66
+ )
67
+ .signature(
68
+ 'decode',
69
+ pytorch_model,
70
+ sample_kwargs={
71
+ 'tokens': decode_token,
72
+ 'input_pos': decode_input_pos,
73
+ 'kv_cache': kv,
74
+ },
56
75
  )
57
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
58
76
  .convert(quant_config=quant_config)
59
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
60
79
  edge_model.export(
61
- f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
80
+ f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
62
81
  )
63
82
 
64
83
 
65
84
  if __name__ == '__main__':
66
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
67
- convert_gemma_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
86
+ convert_gemma2_to_tflite(path)
@@ -13,11 +13,14 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """Example of converting a Gemma model to multi-signature tflite model."""
17
+
16
18
  import os
17
- from pathlib import Path
19
+ import pathlib
18
20
 
19
21
  import ai_edge_torch
20
22
  from ai_edge_torch.generative.examples.gemma import gemma
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
24
  from ai_edge_torch.generative.quantize import quant_recipes
22
25
  import torch
23
26
 
@@ -44,24 +47,40 @@ def convert_gemma_to_tflite(
44
47
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
45
48
  )
46
49
  # Tensors used to trace the model graph during conversion.
47
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
48
- prefill_input_pos = torch.arange(0, prefill_seq_len)
49
- decode_token = torch.tensor([[0]], dtype=torch.long)
50
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
50
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
51
+ prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
52
+ decode_token = torch.tensor([[0]], dtype=torch.int)
53
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
54
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
51
55
 
52
56
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
53
57
  edge_model = (
54
58
  ai_edge_torch.signature(
55
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
59
+ 'prefill',
60
+ pytorch_model,
61
+ sample_kwargs={
62
+ 'tokens': prefill_tokens,
63
+ 'input_pos': prefill_input_pos,
64
+ 'kv_cache': kv,
65
+ },
66
+ )
67
+ .signature(
68
+ 'decode',
69
+ pytorch_model,
70
+ sample_kwargs={
71
+ 'tokens': decode_token,
72
+ 'input_pos': decode_input_pos,
73
+ 'kv_cache': kv,
74
+ },
56
75
  )
57
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
58
76
  .convert(quant_config=quant_config)
59
77
  )
78
+ quant_suffix = 'q8' if quantize else 'f32'
60
79
  edge_model.export(
61
- f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
80
+ f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
62
81
  )
63
82
 
64
83
 
65
84
  if __name__ == '__main__':
66
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
67
- convert_gemma_to_tflite(checkpoint_path)
85
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
86
+ convert_gemma_to_tflite(path)
@@ -12,13 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Example of building a Gemma model.
15
+
16
+ """Example of building a Gemma model."""
16
17
 
17
18
  import os
18
- from pathlib import Path
19
+ import pathlib
19
20
 
20
21
  from ai_edge_torch.generative.layers import attention
21
22
  from ai_edge_torch.generative.layers import builder
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
24
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
23
25
  import ai_edge_torch.generative.layers.model_config as cfg
24
26
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
48
50
  def __init__(self, config: cfg.ModelConfig):
49
51
  super().__init__()
50
52
 
51
- self.config = config
52
53
  # Construct model layers.
53
54
  self.tok_embedding = nn.Embedding(
54
55
  config.vocab_size, config.embedding_dim, padding_idx=0
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
60
61
  )
61
62
  # Gemma re-uses the embedding as the head projection layer.
62
63
  self.lm_head.weight.data = self.tok_embedding.weight.data
64
+ # Gemma has only one block config.
65
+ block_config = config.block_config(0)
63
66
  self.transformer_blocks = nn.ModuleList(
64
- attention.TransformerBlock(config) for _ in range(config.num_layers)
67
+ attention.TransformerBlock(block_config, config)
68
+ for _ in range(config.num_layers)
65
69
  )
66
70
  self.final_norm = builder.build_norm(
67
71
  config.embedding_dim,
68
72
  config.final_norm_config,
69
73
  )
74
+ attn_config = block_config.attn_config
70
75
  self.rope_cache = attn_utils.build_rope_cache(
71
76
  size=config.kv_cache_max,
72
- dim=int(
73
- config.attn_config.rotary_percentage * config.attn_config.head_dim
74
- ),
77
+ dim=int(attn_config.rotary_percentage * attn_config.head_dim),
75
78
  base=10_000,
76
79
  condense_ratio=1,
77
80
  dtype=torch.float32,
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
84
87
  )
85
88
  self.config = config
86
89
 
87
- # The model's forward function takes in additional k/v cache tensors
88
- # and returns the updated k/v cache tensors to the caller.
89
- # This can be eliminated if we handle k/v cache updates inside the model itself.
90
90
  @torch.inference_mode
91
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
92
- _, seq_len = idx.size()
91
+ def forward(
92
+ self,
93
+ tokens: torch.Tensor,
94
+ input_pos: torch.Tensor,
95
+ kv_cache: kv_utils.KVCache,
96
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
97
+ _, seq_len = tokens.size()
93
98
  assert self.config.max_seq_len >= seq_len, (
94
99
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
95
100
  f" {self.config.max_seq_len}"
96
101
  )
102
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
103
+ "The number of transformer blocks and the number of KV cache entries"
104
+ " must be the same."
105
+ )
97
106
 
98
107
  cos, sin = self.rope_cache
99
108
  cos = cos.index_select(0, input_pos)
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
102
111
  mask = mask[:, :, :, : self.config.kv_cache_max]
103
112
 
104
113
  # token embeddings of shape (b, t, n_embd)
105
- x = self.tok_embedding(idx)
114
+ x = self.tok_embedding(tokens)
106
115
  x = x * (self.config.embedding_dim**0.5)
107
116
 
108
- for _, block in enumerate(self.transformer_blocks):
109
- x = block(x, (cos, sin), mask, input_pos)
117
+ updated_kv_entires = []
118
+ for i, block in enumerate(self.transformer_blocks):
119
+ kv_entry = kv_cache.caches[i] if kv_cache else None
120
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
121
+ if kv_entry:
122
+ updated_kv_entires.append(kv_entry)
123
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
110
124
 
111
125
  x = self.final_norm(x)
112
- res = self.lm_head(x) # (b, t, vocab_size)
113
- return res
126
+ logits = self.lm_head(x) # (b, t, vocab_size)
127
+ return {"logits": logits, "kv_cache": updated_kv_cache}
114
128
 
115
129
 
116
130
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
139
153
  epsilon=1e-6,
140
154
  zero_centered=True,
141
155
  )
156
+ block_config = cfg.TransformerBlockConfig(
157
+ attn_config=attn_config,
158
+ ff_config=ff_config,
159
+ pre_attention_norm_config=norm_config,
160
+ post_attention_norm_config=norm_config,
161
+ )
142
162
  config = cfg.ModelConfig(
143
163
  vocab_size=256000,
144
164
  num_layers=18,
145
165
  max_seq_len=8192,
146
166
  embedding_dim=2048,
147
167
  kv_cache_max_len=kv_cache_max_len,
148
- attn_config=attn_config,
149
- ff_config=ff_config,
150
- pre_attention_norm_config=norm_config,
151
- post_attention_norm_config=norm_config,
168
+ block_configs=block_config,
152
169
  final_norm_config=norm_config,
153
- parallel_residual=False,
154
170
  lm_head_use_bias=False,
155
171
  enable_hlfb=True,
156
172
  )
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
159
175
 
160
176
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
161
177
  config = get_model_config_2b(kv_cache_max_len)
162
- config.ff_config.intermediate_size = 128
178
+ # Gemma has only one block config.
179
+ config.block_config(0).ff_config.intermediate_size = 128
163
180
  config.vocab_size = 128
164
181
  config.num_layers = 2
165
182
  config.max_seq_len = 2 * kv_cache_max_len
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
170
187
  config = get_model_config_2b(**kwargs)
171
188
  model = Gemma(config)
172
189
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
173
- # since embedding and lm-head use the same weight, we need to set strict
190
+ # Since embedding and lm-head use the same weight, we need to set strict
174
191
  # to False.
175
192
  loader.load(model, strict=False)
176
193
  model.eval()
177
194
  return model
178
195
 
179
196
 
180
- def define_and_run_2b() -> None:
197
+ def define_and_run_2b(checkpoint_path: str) -> None:
181
198
  """Instantiates and runs a Gemma 2B model."""
182
199
 
183
- current_dir = Path(__file__).parent.resolve()
200
+ current_dir = pathlib.Path(__file__).parent.resolve()
184
201
  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
185
202
 
186
203
  kv_cache_max_len = 1024
187
- checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
188
204
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
189
205
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
190
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
206
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
191
207
  tokens[0, :4] = idx
192
- input_pos = torch.arange(0, kv_cache_max_len)
193
- lm_logits = model.forward(tokens, input_pos)
208
+ input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
209
+ kv = kv_utils.KVCache.from_model_config(model.config)
210
+ output = model.forward(tokens, input_pos, kv)
194
211
  print("comparing with goldens..")
195
212
  assert torch.allclose(
196
- gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
213
+ gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
197
214
  )
198
215
 
199
216
 
200
217
  if __name__ == "__main__":
201
- define_and_run_2b()
218
+ input_checkpoint_path = os.path.join(
219
+ pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
220
+ )
221
+ define_and_run_2b(input_checkpoint_path)