ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/phi2/phi2.py DELETED
@@ -1,189 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Example of building phi-2 model from the Edge Generative API layers.
-
-
-import os
-from pathlib import Path
-
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
-from torch import nn
-
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.fc1",
-    ff_down_proj="model.layers.{}.mlp.fc2",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.dense",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.final_layernorm",
-    lm_head="lm_head",
-)
-
-
-class Phi2(nn.Module):
-  """A Phi-2 model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    self.config = config
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.config = config
-
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
-  @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # forward the model itself
-    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
-
-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
-
-    x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
-
-
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Phi-2 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-2 model.
-  """
-  attn_config = cfg.AttentionConfig(
-      num_heads=32,
-      head_dim=80,
-      num_query_groups=32,
-      rotary_percentage=0.4,
-      qkv_use_bias=True,
-      output_proj_use_bias=True,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.SEQUENTIAL,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=10240,
-      use_bias=True,
-  )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=51200,
-      num_layers=32,
-      max_seq_len=2048,
-      kv_cache_max_len=kv_cache_max_len,
-      embedding_dim=2560,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=True,
-      lm_head_use_bias=True,
-      enable_hlfb=True,
-  )
-  return config
-
-
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config(kv_cache_max_len)
-  config.vocab_size = 128
-  config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
-  config.ff_config.intermediate_size = 128
-  return config
-
-
-def build_model(checkpoint_path, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = Phi2(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  loader.load(model)
-  return model
-
-
-def define_and_run() -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  define_and_run()
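
Per items 5-6 and 40-42 in the file list, the phi-2 example was consolidated rather than dropped: both examples/phi2 and examples/experimental/phi now live at examples/phi. A minimal sketch of building the relocated model, assuming the moved phi2.py keeps the build_model(checkpoint_path, **kwargs) entry point shown in the deleted file; the checkpoint path is illustrative:

import os
from pathlib import Path

from ai_edge_torch.generative.examples.phi import phi2

# Illustrative local checkpoint location, mirroring the deleted
# define_and_run(); any downloaded phi-2 checkpoint directory works.
checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
model = phi2.build_model(checkpoint_path, kv_cache_max_len=1024)

Note that, per item 18 (kv_cache.py) in the diffstat, the relocated example most likely threads an external KV cache through forward(), so the two-argument forward(tokens, input_pos) shown in the deleted file likely no longer applies as-is.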
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py DELETED
@@ -1,176 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# A toy example which has basic transformer block (w/ externalized KV-Cache).
-
-from typing import Tuple
-
-import ai_edge_torch
-from ai_edge_torch import lowertools
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
-import ai_edge_torch.generative.layers.model_config as cfg
-import torch
-import torch.nn as nn
-
-RoPECache = Tuple[torch.Tensor, torch.Tensor]
-
-
-class ToyModelWithExternalKV(torch.nn.Module):
-
-  def __init__(self, config: cfg.ModelConfig) -> None:
-    super().__init__()
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
-    self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
-    )
-    self.config = config
-
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-    x = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.max_seq_len]
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-
-    x = self.final_norm(x)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
-    return self.lm_head(x), updated_kv_cache
-
-
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
-def get_model_config() -> cfg.ModelConfig:
-  attn_config = cfg.AttentionConfig(
-      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
-      intermediate_size=256,
-  )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=150,
-      num_layers=2,
-      max_seq_len=100,
-      embedding_dim=128,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      enable_hlfb=True,
-  )
-  return config
-
-
-def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
-  input_pos = torch.arange(0, 100)
-  return tokens, input_pos
-
-
-def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.tensor([[1]], dtype=torch.long)
-  input_pos = torch.tensor([10])
-  return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.EKVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  # TODO(b/344014416): currently conversion will fail, because we generate int64 index
-  # in dynamic update slice op.
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
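
This deletion is the counterpart of items 13, 18, and 22 in the file list: per the diffstat, the externalized KV-cache machinery graduated from layers/experimental into layers/kv_cache.py, and this toy example was folded into toy_model_with_kv_cache.py. The pattern the deleted file demonstrates is the core change of this release: the cache enters forward() as an argument and an updated cache is returned, so the graph stays functional and export-friendly instead of mutating buffers in place. A self-contained sketch of that update step, using illustrative names that are not the library's actual API:

from typing import NamedTuple

import torch


class KVEntry(NamedTuple):
  """One layer's cache; shapes are [batch, kv_cache_max, heads, head_dim]."""
  k: torch.Tensor
  v: torch.Tensor


def update(entry: KVEntry, input_pos: torch.Tensor,
           k: torch.Tensor, v: torch.Tensor) -> KVEntry:
  # index_copy (the out-of-place variant of index_copy_) writes the new
  # k/v slices at input_pos along the sequence dimension and returns
  # fresh tensors, so the caller gets a new cache, not a mutated one.
  return KVEntry(
      entry.k.index_copy(1, input_pos, k),
      entry.v.index_copy(1, input_pos, v),
  )


# One decode step: write position 10, leave every other slot untouched.
cache = KVEntry(torch.zeros(1, 16, 4, 8), torch.zeros(1, 16, 4, 8))
cache = update(cache, torch.tensor([10]),
               torch.ones(1, 1, 4, 8), torch.ones(1, 1, 4, 8))

Because each forward pass consumes and produces a cache value rather than a side effect, prefill and decode can be exported as two signatures over the same weights, which is exactly what the deleted define_and_run() above does.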