ai-edge-torch-nightly 0.3.0.dev20250124__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (19) hide show
  1. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  2. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  3. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  4. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  5. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  6. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  7. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  8. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  9. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  10. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  11. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  12. ai_edge_torch/generative/utilities/converter.py +15 -4
  13. ai_edge_torch/generative/utilities/model_builder.py +5 -3
  14. ai_edge_torch/version.py +1 -1
  15. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  16. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +19 -10
  17. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  18. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  19. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,14 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,80 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting DeepSeek R1 distilled models to tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.deepseek import deepseek
24
+ from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
+
27
+ _CHECKPOINT_PATH = flags.DEFINE_string(
28
+ 'checkpoint_path',
29
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/deepseek'),
30
+ 'The path to the model checkpoint, or directory holding the checkpoint.',
31
+ )
32
+ _OUTPUT_PATH = flags.DEFINE_string(
33
+ 'output_path',
34
+ '/tmp/',
35
+ 'The path to export the tflite model.',
36
+ )
37
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
+ 'output_name_prefix',
39
+ 'deepseek',
40
+ 'The prefix of the output tflite model name.',
41
+ )
42
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
43
+ 'prefill_seq_lens',
44
+ (8, 64, 128, 256, 512, 1024),
45
+ 'List of the maximum sizes of prefill input tensors.',
46
+ )
47
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
+ 'kv_cache_max_len',
49
+ 1280,
50
+ 'The maximum size of KV cache buffer, including both prefill and decode.',
51
+ )
52
+ _QUANTIZE = flags.DEFINE_bool(
53
+ 'quantize',
54
+ True,
55
+ 'Whether the model should be quantized.',
56
+ )
57
+ _LORA_RANKS = flags.DEFINE_multi_integer(
58
+ 'lora_ranks',
59
+ None,
60
+ 'If set, the model will be converted with the provided list of LoRA ranks.',
61
+ )
62
+
63
+
64
+ def main(_):
65
+ pytorch_model = deepseek.build_model(
66
+ _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
67
+ )
68
+ converter.convert_to_tflite(
69
+ pytorch_model,
70
+ output_path=_OUTPUT_PATH.value,
71
+ output_name_prefix=_OUTPUT_NAME_PREFIX.value,
72
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
73
+ quantize=_QUANTIZE.value,
74
+ lora_ranks=_LORA_RANKS.value,
75
+ export_config=ExportConfig(),
76
+ )
77
+
78
+
79
+ if __name__ == '__main__':
80
+ app.run(main)
@@ -0,0 +1,92 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building DeepSeek R1 distilled models."""
17
+
18
+ import ai_edge_torch.generative.layers.model_config as cfg
19
+ from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
21
+
22
+ TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
23
+
24
+
25
+ class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
26
+ """A DeepSeek distilled model based on Qwen."""
27
+ pass
28
+
29
+
30
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
31
+ """Returns the model config for a DeepSeek-R1-Distill-Qwen 1.5B model.
32
+
33
+ Args:
34
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
35
+ is 1024.
36
+
37
+ Returns:
38
+ The model config for a DeepSeek R1 distilled Qwen model.
39
+ """
40
+ attn_config = cfg.AttentionConfig(
41
+ num_heads=12,
42
+ head_dim=128,
43
+ num_query_groups=2,
44
+ rotary_base=10000,
45
+ rotary_percentage=1.0,
46
+ qkv_use_bias=True,
47
+ )
48
+ ff_config = cfg.FeedForwardConfig(
49
+ type=cfg.FeedForwardType.GATED,
50
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
51
+ intermediate_size=8960,
52
+ )
53
+ norm_config = cfg.NormalizationConfig(
54
+ type=cfg.NormalizationType.RMS_NORM,
55
+ epsilon=1e-06,
56
+ )
57
+ block_config = cfg.TransformerBlockConfig(
58
+ attn_config=attn_config,
59
+ ff_config=ff_config,
60
+ pre_attention_norm_config=norm_config,
61
+ post_attention_norm_config=norm_config,
62
+ )
63
+ config = cfg.ModelConfig(
64
+ vocab_size=151936,
65
+ num_layers=28,
66
+ max_seq_len=4096,
67
+ embedding_dim=1536,
68
+ kv_cache_max_len=kv_cache_max_len,
69
+ block_configs=block_config,
70
+ final_norm_config=norm_config,
71
+ lm_head_share_weight_with_embedding=False,
72
+ enable_hlfb=True,
73
+ )
74
+ return config
75
+
76
+
77
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
78
+ config = get_model_config(**kwargs)
79
+ config.vocab_size = 128
80
+ config.num_layers = 2
81
+ # DeepSeek-R1-Distill-Qwen has only one block config.
82
+ config.block_config(0).ff_config.intermediate_size = 64
83
+ return config
84
+
85
+
86
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
87
+ return model_builder.build_decoder_only_model(
88
+ checkpoint_path=checkpoint_path,
89
+ config=get_model_config(**kwargs),
90
+ tensor_names=TENSOR_NAMES,
91
+ model_class=DeepSeekDistillQwen,
92
+ )
@@ -0,0 +1,70 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.deepseek import deepseek
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
25
+ from ai_edge_torch.generative.utilities import verifier
26
+ import transformers
27
+
28
+
29
+ _PROMPTS = flags.DEFINE_multi_string(
30
+ "prompts",
31
+ "What is the meaning of life?",
32
+ "The input prompts to generate answers.",
33
+ )
34
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
35
+ "max_new_tokens",
36
+ 30,
37
+ "The maximum size of the generated tokens.",
38
+ )
39
+
40
+
41
+ def main(_):
42
+ checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
43
+ logging.info("Loading the original model from: %s", checkpoint)
44
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
45
+
46
+ # Locate the cached dir.
47
+ cached_config_file = transformers.utils.cached_file(
48
+ checkpoint, transformers.utils.CONFIG_NAME
49
+ )
50
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
51
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
52
+ reauthored_model = deepseek.build_model(reauthored_checkpoint)
53
+
54
+ logging.info("Loading the tokenizer from: %s", checkpoint)
55
+ tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
56
+
57
+ verifier.verify_reauthored_model(
58
+ original_model=transformers_verifier.TransformersModelWrapper(
59
+ original_model
60
+ ),
61
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
62
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
63
+ generate_prompts=_PROMPTS.value,
64
+ max_new_tokens=_MAX_NEW_TOKENS.value,
65
+ atol=1e-04,
66
+ )
67
+
68
+
69
+ if __name__ == "__main__":
70
+ app.run(main)
@@ -85,6 +85,7 @@ def convert_stable_diffusion_to_tflite(
85
85
  clip.TENSOR_NAMES,
86
86
  )
87
87
  loader.load(clip_model, strict=False)
88
+ clip_model.eval()
88
89
 
89
90
  diffusion_model = diffusion.Diffusion(
90
91
  diffusion.get_model_config(batch_size=2, device_type=_DEVICE_TYPE.value)
@@ -93,6 +94,7 @@ def convert_stable_diffusion_to_tflite(
93
94
  diffusion_ckpt_path, diffusion.TENSOR_NAMES
94
95
  )
95
96
  diffusion_loader.load(diffusion_model, strict=False)
97
+ diffusion_model.eval()
96
98
 
97
99
  decoder_model = decoder.Decoder(
98
100
  decoder.get_model_config(device_type=_DEVICE_TYPE.value)
@@ -101,6 +103,7 @@ def convert_stable_diffusion_to_tflite(
101
103
  decoder_ckpt_path, decoder.TENSOR_NAMES
102
104
  )
103
105
  decoder_loader.load(decoder_model, strict=False)
106
+ decoder_model.eval()
104
107
 
105
108
  # TODO(yichunk): enable image encoder conversion
106
109
  # if encoder_ckpt_path is not None:
@@ -0,0 +1,14 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
@@ -0,0 +1,269 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Common building blocks for a GPU-specific Attention layer.
17
+
18
+ This is a temporary implementation for the GPU. It is subject to change/removal
19
+ at any time.
20
+ """
21
+
22
+ from typing import Optional, Tuple, Union
23
+
24
+ from ai_edge_torch.generative.layers import builder
25
+ from ai_edge_torch.generative.layers import lora as lora_utils
26
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
27
+ from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
28
+ import ai_edge_torch.generative.layers.model_config as cfg
29
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
30
+ import torch
31
+ from torch import nn
32
+
33
+
34
+ class TransformerBlock(nn.Module):
35
+
36
+ def __init__(
37
+ self,
38
+ config: cfg.TransformerBlockConfig,
39
+ model_config: cfg.ModelConfig,
40
+ ) -> None:
41
+ """Initialize an instance of the TransformerBlock.
42
+
43
+ Args:
44
+ config (cfg.TransformerBlockConfig): the configuration object for this
45
+ transformer block.
46
+ model_config (cfg.ModelConfig): the configuration object for the model
47
+ this transformer block belongs to.
48
+ """
49
+ super().__init__()
50
+ self.pre_atten_norm = builder.build_norm(
51
+ model_config.embedding_dim,
52
+ config.pre_attention_norm_config,
53
+ )
54
+ self.atten_func = CausalSelfAttention(
55
+ model_config.batch_size,
56
+ model_config.embedding_dim,
57
+ config.attn_config,
58
+ model_config.enable_hlfb,
59
+ )
60
+ self.post_atten_norm = builder.build_norm(
61
+ model_config.embedding_dim,
62
+ config.post_attention_norm_config,
63
+ )
64
+ self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
65
+ self.config = config
66
+
67
+ def forward(
68
+ self,
69
+ x: torch.Tensor,
70
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
71
+ mask: Optional[torch.Tensor] = None,
72
+ input_pos: Optional[torch.Tensor] = None,
73
+ kv_cache: kv_utils.KVCacheEntryBase = None,
74
+ lora: Optional[lora_utils.LoRAEntry] = None,
75
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
76
+ """Forward function of the TransformerBlock.
77
+
78
+ Args:
79
+ x (torch.Tensor): the input tensor.
80
+ rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
81
+ mask (torch.Tensor): the optional mask tensor.
82
+ input_pos (torch.Tensor): the optional input position tensor.
83
+ kv_cache (KVCacheEntryBase): the optional kv cache entry.
84
+ lora (LoRAEntry): the optional lora entry.
85
+
86
+ Returns:
87
+ output activation from this transformer block, and updated kv cache (if
88
+ passed in).
89
+ """
90
+ kv = None
91
+ if self.config.parallel_residual:
92
+ x_norm = self.pre_atten_norm(x)
93
+ atten_func_out = self.atten_func(
94
+ x_norm, rope, mask, input_pos, kv_cache, lora
95
+ )
96
+ if kv_cache is None:
97
+ attn_out = atten_func_out
98
+ else:
99
+ attn_out, kv = atten_func_out
100
+ ff_out = self.ff(x_norm)
101
+ output = x + attn_out + ff_out
102
+ else:
103
+ x_norm = self.pre_atten_norm(x)
104
+ atten_func_out = self.atten_func(
105
+ x_norm, rope, mask, input_pos, kv_cache, lora
106
+ )
107
+ if kv_cache is None:
108
+ attn_out = atten_func_out
109
+ else:
110
+ attn_out, kv = atten_func_out
111
+ x = x + attn_out
112
+ x_norm = self.post_atten_norm(x)
113
+ output = x + self.ff(x_norm)
114
+
115
+ return output if kv is None else (output, kv)
116
+
117
+
118
+ class CausalSelfAttention(nn.Module):
119
+
120
+ def __init__(
121
+ self,
122
+ batch_size: int,
123
+ dim: int,
124
+ config: cfg.AttentionConfig,
125
+ enable_hlfb: bool,
126
+ ) -> None:
127
+ """Initialize an instance of CausalSelfAttention.
128
+
129
+ Args:
130
+ batch_size (int): batch size of the input tensor.
131
+ dim (int): causal attention's input/output dimension.
132
+ config (cfg.AttentionConfig): attention specific configurations.
133
+ enable_hlfb (bool): whether hlfb is enabled or not.
134
+ """
135
+ super().__init__()
136
+ self.kv_cache = None
137
+ self.batch_size = batch_size
138
+ qkv_shape = (
139
+ config.num_heads + 2 * config.num_query_groups
140
+ ) * config.head_dim
141
+ output_shape = config.num_heads * config.head_dim
142
+ # Key, query, value projections for all heads.
143
+ self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
144
+ self.output_projection = nn.Linear(
145
+ output_shape, dim, bias=config.output_proj_use_bias
146
+ )
147
+ self.query_norm = builder.build_norm(
148
+ config.head_dim, config.query_norm_config
149
+ )
150
+ self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
151
+ self.config = config
152
+ self.enable_hlfb = enable_hlfb
153
+ self.sdpa_func = sdpa.scaled_dot_product_attention
154
+
155
+ def forward(
156
+ self,
157
+ x: torch.Tensor,
158
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
159
+ mask: Optional[torch.Tensor] = None,
160
+ input_pos: Optional[torch.Tensor] = None,
161
+ kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
162
+ lora: Optional[lora_utils.LoRAEntry] = None,
163
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
164
+ """Forward function of the CausalSelfAttention layer, which can support
165
+
166
+ MQA, GQA and MHA.
167
+
168
+ Args:
169
+ x (torch.Tensor): the input tensor.
170
+ rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
171
+ mask (torch.Tensor): the optional mask tensor.
172
+ input_pos (torch.Tensor): the optional input position tensor.
173
+ kv_cache (KVCacheEntryBase): the KV cache entry corresponding to this
174
+ module.
175
+ lora (LoRAEntry): the optional lora entry.
176
+
177
+ Returns:
178
+ output activation from this self attention layer, and the updated
179
+ KV cache entry (if passed in).
180
+ """
181
+ # Batch size, sequence length, embedding dimensionality.
182
+ B, T, E = x.size()
183
+ assert B == self.batch_size, (
184
+ "batch size of input tensor must match with the batch size specified in"
185
+ " the model configuration."
186
+ )
187
+
188
+ qkv = self.qkv_projection(x)
189
+
190
+ # Assemble into a number of query groups to support MHA, MQA and GQA.
191
+ q_per_kv = self.config.num_heads // self.config.num_query_groups
192
+ # Each group has >=1 queries, 1 key, and 1 value.
193
+ if self.config.qkv_transpose_before_split:
194
+ qkv = qkv.view(B, T, -1, self.config.head_dim)
195
+ q, k, v = qkv.split(
196
+ (
197
+ q_per_kv * self.config.num_query_groups,
198
+ self.config.num_query_groups,
199
+ self.config.num_query_groups,
200
+ ),
201
+ dim=-2,
202
+ )
203
+ else:
204
+ qkv = qkv.view(B, T, self.config.num_query_groups, -1)
205
+ q, k, v = qkv.split(
206
+ (
207
+ q_per_kv * self.config.head_dim,
208
+ self.config.head_dim,
209
+ self.config.head_dim,
210
+ ),
211
+ dim=-1,
212
+ )
213
+
214
+ if lora is not None:
215
+ q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
216
+ k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
217
+ v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
218
+
219
+ q = self.query_norm(q)
220
+ k = self.key_norm(k)
221
+
222
+ q = q.reshape(B, T, -1, self.config.head_dim)
223
+ k = k.reshape(B, T, -1, self.config.head_dim)
224
+ v = v.reshape(B, T, -1, self.config.head_dim)
225
+
226
+ if rope is not None:
227
+ # Compute rotary positional embedding for query and key.
228
+ n_elem = int(self.config.rotary_percentage * self.config.head_dim)
229
+ cos, sin = rope
230
+ q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
231
+
232
+ # Transpose k/v to specific layout for GPU implementation.
233
+ b, _, n, h = q.shape
234
+ g = n // self.config.num_query_groups
235
+ # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
236
+ q = q.permute(0, 2, 1, 3).reshape(
237
+ 1, b * self.config.num_query_groups, g * T, h
238
+ )
239
+
240
+ k = k.permute(0, 2, 1, 3).reshape(
241
+ 1, -1, T, self.config.head_dim
242
+ ) # 1, bk, s, h
243
+ v = v.permute(0, 2, 3, 1).reshape(
244
+ 1, -1, self.config.head_dim, T
245
+ ) # 1, bk, h, s
246
+
247
+ if kv_cache is not None:
248
+ kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
249
+ k, v = kv_cache.k_cache, kv_cache.v_cache
250
+
251
+ sdpa_out = self.sdpa_func(
252
+ kv_cache,
253
+ q,
254
+ k,
255
+ v,
256
+ self.config.head_dim,
257
+ mask=mask,
258
+ softcap=self.config.logit_softcap,
259
+ ) # 1, bk, gt, h
260
+ sdpa_out = (
261
+ sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
262
+ )
263
+
264
+ # Compute the output projection.
265
+ y = self.output_projection(sdpa_out)
266
+ if lora is not None:
267
+ y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
268
+
269
+ return y if kv_cache is None else (y, kv_cache)
@@ -0,0 +1,314 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utility functions for KV Cache.
17
+
18
+ This is an experimental implementation and is subject to change at any time.
19
+ """
20
+
21
+ import dataclasses
22
+ from typing import List, Tuple
23
+
24
+ from ai_edge_torch import hlfb
25
+ from ai_edge_torch.generative.layers import model_config
26
+ from ai_edge_torch.generative.layers.experimental import types as types
27
+ from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.utils._pytree as pytree
31
+
32
+ BATCH_SIZE = 1
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class KVCacheEntryBase:
37
+ """A single cache entry that includes K and V caches.
38
+
39
+ The caches are built based on the provided config with the shape of
40
+ (batch_size=1, kv_cache_max, num_query_groups, head_dim).
41
+ """
42
+
43
+ k_cache: torch.Tensor
44
+ v_cache: torch.Tensor
45
+
46
+ @classmethod
47
+ def _from_model_config(
48
+ cls,
49
+ kv_cache_max: int,
50
+ config: model_config.AttentionConfig,
51
+ k_shape: Tuple,
52
+ v_shape: Tuple,
53
+ dtype: torch.dtype = torch.float32,
54
+ device: torch.device = None,
55
+ ) -> "KVCacheEntryBase":
56
+ """Build an instance of the class based on model config."""
57
+ k = torch.zeros(k_shape, dtype=dtype, device=device)
58
+ v = torch.zeros(v_shape, dtype=dtype, device=device)
59
+ obj = cls(k_cache=k, v_cache=v)
60
+ return obj
61
+
62
+ @classmethod
63
+ def from_model_config(
64
+ cls,
65
+ kv_cache_max: int,
66
+ config: model_config.AttentionConfig,
67
+ dtype: torch.dtype = torch.float32,
68
+ device: torch.device = None,
69
+ ) -> "KVCacheEntryBase":
70
+ """Build an instance of the class based on model config."""
71
+ shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
72
+ return cls._from_model_config(
73
+ kv_cache_max, config, shape, shape, dtype, device
74
+ )
75
+
76
+
77
+ @dataclasses.dataclass
78
+ class KVCacheEntryBTNH(KVCacheEntryBase):
79
+ k_type = types.BTNH()
80
+ v_type = types.BTNH()
81
+
82
+
83
+ @dataclasses.dataclass
84
+ class KVCacheEntryTransposed(KVCacheEntryBase):
85
+
86
+ k_type = types.BNTH()
87
+ v_type = types.BNHT()
88
+
89
+ @classmethod
90
+ def from_model_config(
91
+ cls,
92
+ kv_cache_max: int,
93
+ config: model_config.AttentionConfig,
94
+ dtype: torch.dtype = torch.float32,
95
+ device: torch.device = None,
96
+ ) -> "KVCacheEntryBase":
97
+ """Build an instance of the class based on model config."""
98
+ num_kv_heads = config.num_query_groups
99
+ k_shape = (
100
+ 1,
101
+ BATCH_SIZE * num_kv_heads,
102
+ kv_cache_max,
103
+ config.head_dim,
104
+ ) # 1, bk, s, h
105
+ v_shape = (
106
+ 1,
107
+ BATCH_SIZE * num_kv_heads,
108
+ config.head_dim,
109
+ kv_cache_max,
110
+ ) # 1, bk, h, s
111
+ return cls._from_model_config(
112
+ kv_cache_max, config, k_shape, v_shape, dtype, device
113
+ )
114
+
115
+
116
+ @dataclasses.dataclass
117
+ class KVCacheBase:
118
+ """A utility class for holding KV cache entries per layer."""
119
+
120
+ caches: Tuple[KVCacheEntryBase, ...]
121
+
122
+ @classmethod
123
+ def _from_model_config(
124
+ cls,
125
+ kv_entry_cls,
126
+ config: model_config.ModelConfig,
127
+ dtype: torch.dtype = torch.float32,
128
+ device: torch.device = None,
129
+ ) -> "KVCacheBase":
130
+ caches = [
131
+ kv_entry_cls.from_model_config(
132
+ config.kv_cache_max,
133
+ config.block_config(idx).attn_config,
134
+ dtype,
135
+ device,
136
+ )
137
+ for idx in range(config.num_layers)
138
+ ]
139
+ obj = cls(caches=tuple(caches))
140
+ return obj
141
+
142
+ @classmethod
143
+ def from_model_config(
144
+ cls,
145
+ config: model_config.ModelConfig,
146
+ dtype: torch.dtype = torch.float32,
147
+ device: torch.device = None,
148
+ ) -> "KVCacheBase":
149
+ """Build an instance of the class based on model config.
150
+
151
+ Args:
152
+ config (ModelConfig): Model config used for building the cache.
153
+ dtype (torch.dtype, optional): The data type of the cache tensor.
154
+ Defaults to torch.float32.
155
+ device (torch.device, optional): The device placement of the cache
156
+ tensors. Defaults to None.
157
+
158
+ Returns:
159
+ KVCacheBase: The created cache object.
160
+ """
161
+ return cls._from_model_config(
162
+ KVCacheEntryBase, config=config, dtype=dtype, device=device
163
+ )
164
+
165
+ def flatten(self) -> List[torch.Tensor]:
166
+ """Flatten the cache entries into a list of tensors with order k_i, v_i."""
167
+ flattened, _ = _flatten_kvc(self)
168
+ return flattened
169
+
170
+
171
+ @dataclasses.dataclass
172
+ class KVCacheBTNH(KVCacheBase):
173
+
174
+ @classmethod
175
+ def from_model_config(
176
+ cls,
177
+ config: model_config.ModelConfig,
178
+ dtype: torch.dtype = torch.float32,
179
+ device: torch.device = None,
180
+ ) -> "KVCacheBTNH":
181
+ return cls._from_model_config(
182
+ KVCacheEntryBTNH, config=config, dtype=dtype, device=device
183
+ )
184
+
185
+
186
+ @dataclasses.dataclass
187
+ class KVCacheTransposed(KVCacheBase):
188
+
189
+ @classmethod
190
+ def from_model_config(
191
+ cls,
192
+ config: model_config.ModelConfig,
193
+ dtype: torch.dtype = torch.float32,
194
+ device: torch.device = None,
195
+ ) -> "KVCacheBTNH":
196
+ return cls._from_model_config(
197
+ KVCacheEntryTransposed, config=config, dtype=dtype, device=device
198
+ )
199
+
200
+
201
+ def _flatten_kvc(kvc: KVCacheBase) -> Tuple[List[str], List[str]]:
202
+ flattened = []
203
+ flat_names = []
204
+ none_names = []
205
+ for i, kv_entry in enumerate(kvc.caches):
206
+ flattened.append(kv_entry.k_cache)
207
+ flat_names.append(f"k_{i}")
208
+ flattened.append(kv_entry.v_cache)
209
+ flat_names.append(f"v_{i}")
210
+ return flattened, [flat_names, none_names]
211
+
212
+
213
+ def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
214
+ flattened, (flat_names, none_names) = _flatten_kvc(kvc)
215
+ return [
216
+ (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
217
+ ], flat_names
218
+
219
+
220
+ def _unflatten_kvc(
221
+ values: List[torch.Tensor], context: Tuple[List, List]
222
+ ) -> KVCacheBase:
223
+ assert len(values) % 2 == 0, "Found odd number of K and V entries."
224
+ num_layers = len(values) // 2
225
+ flat_names = context[0]
226
+ kv_entries = []
227
+ for i in range(num_layers):
228
+ k_cache_idx = flat_names.index(f"k_{i}")
229
+ v_cache_idx = flat_names.index(f"v_{i}")
230
+ kv_entries.append(
231
+ KVCacheEntryBase(
232
+ k_cache=values[k_cache_idx], v_cache=values[v_cache_idx]
233
+ )
234
+ )
235
+ obj = KVCacheBase(tuple(kv_entries))
236
+ return obj
237
+
238
+
239
+ pytree.register_pytree_node(
240
+ KVCacheTransposed,
241
+ _flatten_kvc,
242
+ _unflatten_kvc,
243
+ flatten_with_keys_fn=_flatten_kvc_with_keys,
244
+ serialized_type_name="",
245
+ )
246
+
247
+ pytree.register_pytree_node(
248
+ KVCacheBase,
249
+ _flatten_kvc,
250
+ _unflatten_kvc,
251
+ flatten_with_keys_fn=_flatten_kvc_with_keys,
252
+ serialized_type_name="",
253
+ )
254
+
255
+
256
+ def update(
257
+ cache: KVCacheEntryBase,
258
+ input_pos: torch.Tensor,
259
+ k_slice: torch.Tensor,
260
+ v_slice: torch.Tensor,
261
+ use_dus: bool = True,
262
+ ) -> KVCacheEntryBase:
263
+ """Out-of-place update of the cache buffer.
264
+
265
+ Args:
266
+ cache (KVCacheEntryBase): The original cache buffer.
267
+ input_pos (torch.Tensor): The update slice positions.
268
+ k_slice (torch.Tensor): The K slice to be updated in the new cache.
269
+ v_slice (torch.Tensor): The V slice to be updated in the new cache.
270
+
271
+ Returns:
272
+ KVCacheEntryBase: The updated KVCacheBase entry based on the passed
273
+ inputs.
274
+ """
275
+ update_kv_cache = _update_kv_impl
276
+ return update_kv_cache(cache, input_pos, k_slice, v_slice)
277
+
278
+
279
+ def _get_slice_indices(
280
+ positions: torch.Tensor, cache_dim: int, ts_idx: int
281
+ ) -> torch.Tensor:
282
+ """Returns the slice indices."""
283
+ positions = positions.float()[0].reshape(
284
+ 1,
285
+ )
286
+
287
+ zeros = torch.zeros((1,), dtype=torch.float32)
288
+ indices = []
289
+ for i in range(cache_dim):
290
+ if i == ts_idx:
291
+ indices.append(positions)
292
+ else:
293
+ indices.append(zeros)
294
+ slice_indices = torch.cat(indices, dim=0)
295
+ slice_indices = slice_indices.int()
296
+ return slice_indices
297
+
298
+
299
+ def _update_kv_impl(
300
+ cache: KVCacheEntryTransposed,
301
+ input_pos: torch.Tensor,
302
+ k_slice: torch.Tensor,
303
+ v_slice: torch.Tensor,
304
+ ) -> KVCacheEntryTransposed:
305
+ """Update the cache buffer with High Level Function Boundary annotation."""
306
+ cache_dim = 4
307
+ k_ts_idx = 2
308
+ v_ts_idx = 3
309
+ positions = input_pos.clone()
310
+ k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
311
+ v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
312
+ k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
313
+ v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
314
+ return KVCacheEntryTransposed(k, v)
@@ -0,0 +1,97 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Implements scaled dot product attention. This is experimental and
16
+ # GPU-specific code.
17
+
18
+ import math
19
+ from typing import Optional
20
+
21
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
22
+ from ai_edge_torch.generative.layers.experimental import types
23
+ from ai_edge_torch.generative.utilities import bmm_4d as bmm_lib
24
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
25
+ from multipledispatch import dispatch
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+
30
def scaled_dot_product_attention(
    kv: kv_utils.KVCacheBase,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    head_size: int,
    mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    softcap: Optional[float] = None,
):
  """Runs scaled dot product attention for the K/V layout declared by `kv`.

  The implementation is chosen by dispatching `_sdpa` on `kv.k_type` and
  `kv.v_type`; every other argument is forwarded by keyword.

  Args:
    kv: KV cache object; must expose `k_type` and `v_type`.
    query: Query tensor.
    key: Key tensor.
    value: Value tensor.
    head_size: Head dimension, forwarded to `_sdpa`.
    mask: Optional attention mask, forwarded to `_sdpa`.
    scale: Optional scale, forwarded to `_sdpa`.
    softcap: Optional soft cap, forwarded to `_sdpa`.

  Returns:
    The result of the selected `_sdpa` implementation.

  Raises:
    ValueError: If `kv` does not expose both `k_type` and `v_type`.
  """
  if not (hasattr(kv, "k_type") and hasattr(kv, "v_type")):
    # Describe the layout on a best-effort basis: looking it up must never
    # mask the ValueError with an AttributeError/IndexError of its own.
    try:
      first_entry = kv.caches[0]
    except (AttributeError, IndexError, KeyError, TypeError):
      first_entry = None
    k_kind = type(getattr(first_entry, "k_type", None))
    v_kind = type(getattr(first_entry, "v_type", None))
    raise ValueError(
        f"SDPA for K type {k_kind} and V type {v_kind} not supported."
    )

  return _sdpa(
      kv.k_type,
      kv.v_type,
      query=query,
      key=key,
      value=value,
      head_size=head_size,
      mask=mask,
      scale=scale,
      softcap=softcap,
  )
56
+
57
+
58
@dispatch(types.BNTH, types.BNHT)
def _sdpa(k_type, v_type, *args, **kwargs):
  """SDPA overload registered for K of `types.BNTH` and V of `types.BNHT`.

  All inputs are read from keyword arguments:
    query, key, value: Tensors fed to `bmm_lib.bmm_4d`.
    head_size: Used for the default scale `1 / sqrt(head_size)`.
    mask: Required; `mask.shape[2]` is used to regroup the logits before the
      mask is added.
    scale: Optional multiplier applied to `query`.
    softcap: Optional; when given, logits become `tanh(logits / softcap) *
      softcap`.

  Returns:
    The result of `bmm_4d(probs, value)`.
  """
  q = kwargs["query"]
  k = kwargs["key"]
  v = kwargs["value"]
  head_size = kwargs["head_size"]
  mask = kwargs.get("mask")
  scale = kwargs.get("scale")
  softcap = kwargs.get("softcap")

  q = q * (1.0 / math.sqrt(head_size) if scale is None else scale)

  assert mask is not None, "Mask should not be None!"
  t = mask.shape[2]

  scores = bmm_lib.bmm_4d(q, k)

  # Split the third logits axis (gt) into (g, t) so the mask can be added.
  _, bk, gt, s = scores.shape
  scores = scores.reshape((bk, gt // t, t, s))
  if softcap is not None:
    scores = torch.tanh(scores / softcap) * softcap

  # Restore the (1, bk, gt, s) layout before normalising over the last axis.
  masked = (scores + mask).reshape(1, bk, gt, s)
  probs = F.softmax(masked, dim=-1).type_as(k)

  return bmm_lib.bmm_4d(probs, v)  # 1, bk, gt, h
92
+
93
+
94
@dispatch(object, object)
def _sdpa(k_type, v_type, *args, **kwargs):
  """Catch-all overload: rejects K/V type pairs with no implementation.

  Raises:
    ValueError: Always.
  """
  message = f"No implementations for k={k_type} and v={v_type}"
  raise ValueError(message)
@@ -0,0 +1,97 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # A listing of types that describe the K and V tensors in KV caches.
16
+
17
+ import enum
18
+ from enum import Enum, auto
19
+ from typing import Tuple
20
+ from torch import nn
21
+
22
+
23
@enum.unique
class TensorDims(enum.Enum):
  """Logical tensor axes that a dimension order can be composed of."""

  BATCH = auto()
  SEQUENCE = auto()
  NUM_HEADS = auto()
  HEAD_DIM = auto()
  # Often num_heads * head_dim.
  MODEL_DIM = auto()


# One-letter code per axis; the codes are concatenated to name a dimension
# order (for example 'BNTH').
DIM_TO_LETTER = {
    TensorDims.BATCH: 'B',
    TensorDims.SEQUENCE: 'T',
    TensorDims.NUM_HEADS: 'N',
    TensorDims.HEAD_DIM: 'H',
    TensorDims.MODEL_DIM: 'D',
}
39
+
40
+
41
class TensorDimensionMeta(type):
  """Metaclass to create classes representing an order of tensor dimensions.

  The dimension order is stored on every created class as the `dimensions`
  class attribute, and `repr()` of such a class is just its name.
  """

  def __new__(cls, name, bases, attrs, dimensions: Tuple[TensorDims, ...]):
    """Creates a new class with the given name and tensor dimension order.

    Args:
      name: Name of the new class.
      bases: Base classes for the new class.
      attrs: Attributes for the new class. Not modified.
      dimensions: A tuple of TensorDims defining the order.

    Returns:
      The newly created class, with `dimensions` set as a class attribute.
    """
    # Work on a copy so the caller's mapping is not mutated as a side effect.
    namespace = dict(attrs)
    namespace['dimensions'] = dimensions
    return super().__new__(cls, name, bases, namespace)

  def __init__(cls, name, bases, attrs, dimensions: Tuple[TensorDims, ...]):
    # `dimensions` is accepted here only to be dropped: `type.__init__` takes
    # no keyword arguments.
    super().__init__(name, bases, attrs)

  def __repr__(cls):
    return cls.__name__
64
+
65
+
66
def create_tensor_dimension_order_class(dims: Tuple[TensorDims, ...]):
  """Creates a TensorDimensionMeta class with the specified dimensions.

  The class name is the concatenation of the `DIM_TO_LETTER` code of each
  entry of `dims`, in order.

  Args:
    dims: A tuple of TensorDims giving the dimension order.

  Returns:
    A new class representing the tensor dimension order.

  Raises:
    KeyError: If an entry of `dims` has no letter in `DIM_TO_LETTER`.
  """
  name = ''.join(DIM_TO_LETTER[d] for d in dims)
  # Derive from nn.Module for torch tracing compatibility.
  return TensorDimensionMeta(name, (nn.Module,), {}, dimensions=dims)
78
+
79
+
80
# Dimension-order classes. Each class is named after the letter codes of its
# axes and carries the axis order in its `dimensions` attribute.

# Order: batch, sequence, num_heads, head_dim.
BTNH = create_tensor_dimension_order_class((
    TensorDims.BATCH,
    TensorDims.SEQUENCE,
    TensorDims.NUM_HEADS,
    TensorDims.HEAD_DIM,
))
# Order: batch, num_heads, sequence, head_dim.
BNTH = create_tensor_dimension_order_class((
    TensorDims.BATCH,
    TensorDims.NUM_HEADS,
    TensorDims.SEQUENCE,
    TensorDims.HEAD_DIM,
))
# Order: batch, num_heads, head_dim, sequence.
BNHT = create_tensor_dimension_order_class((
    TensorDims.BATCH,
    TensorDims.NUM_HEADS,
    TensorDims.HEAD_DIM,
    TensorDims.SEQUENCE,
))
@@ -17,6 +17,7 @@
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
20
+ from ai_edge_torch.generative.examples.deepseek import deepseek
20
21
  from ai_edge_torch.generative.examples.gemma import gemma1
21
22
  from ai_edge_torch.generative.examples.gemma import gemma2
22
23
  from ai_edge_torch.generative.examples.llama import llama
@@ -150,16 +151,15 @@ class TestModelConversion(googletest.TestCase):
150
151
  ai_edge_torch.config.in_oss,
151
152
  reason="tests with custom ops are not supported in oss",
152
153
  )
153
-
154
154
  def test_smollm2(self):
155
155
  config = smollm.get_fake_model_config_v2()
156
156
  pytorch_model = smollm.SmolLM2(config).eval()
157
157
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
158
+
158
159
  @googletest.skipIf(
159
160
  ai_edge_torch.config.in_oss,
160
161
  reason="tests with custom ops are not supported in oss",
161
162
  )
162
-
163
163
  def test_openelm(self):
164
164
  config = openelm.get_fake_model_config()
165
165
  pytorch_model = openelm.OpenELM(config).eval()
@@ -174,6 +174,15 @@ class TestModelConversion(googletest.TestCase):
174
174
  pytorch_model = qwen.Qwen(config).eval()
175
175
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
176
176
 
177
+ @googletest.skipIf(
178
+ ai_edge_torch.config.in_oss,
179
+ reason="tests with custom ops are not supported in oss",
180
+ )
181
+ def test_deepseek(self):
182
+ config = deepseek.get_fake_model_config()
183
+ pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
184
+ self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
185
+
177
186
  @googletest.skipIf(
178
187
  ai_edge_torch.config.in_oss,
179
188
  reason="tests with custom ops are not supported in oss",
@@ -19,7 +19,6 @@ import os
19
19
  from typing import Optional, Union
20
20
  from ai_edge_torch._convert import converter as converter_utils
21
21
  from ai_edge_torch.generative.layers import lora as lora_utils
22
- import ai_edge_torch.generative.layers.kv_cache as kv_utils
23
22
  import ai_edge_torch.generative.layers.model_config as cfg
24
23
  from ai_edge_torch.generative.quantize import quant_recipes
25
24
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -151,9 +150,21 @@ def _export_helper(
151
150
  else None
152
151
  )
153
152
 
153
+ if export_config.prefill_mask is None:
154
+ prefill_masks = None
155
+ elif isinstance(export_config.prefill_mask, torch.Tensor):
156
+ prefill_masks = [export_config.prefill_mask]
157
+ elif isinstance(export_config.prefill_mask, list):
158
+ prefill_masks = export_config.prefill_mask
159
+ else:
160
+ raise ValueError('Prefill masks unrecognized.')
161
+
162
+ if prefill_masks:
163
+ assert len(prefill_masks) == len(prefill_seq_lens)
164
+
154
165
  decode_token = torch.tensor([[0]], dtype=torch.int)
155
166
  decode_input_pos = torch.tensor([0], dtype=torch.int)
156
- kv = kv_utils.KVCache.from_model_config(config)
167
+ kv = export_config.kvcache_cls.from_model_config(config)
157
168
 
158
169
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
159
170
 
@@ -174,8 +185,8 @@ def _export_helper(
174
185
  'input_pos': prefill_input_pos,
175
186
  'kv_cache': kv,
176
187
  }
177
- if export_config.prefill_mask is not None:
178
- sample_kwargs['mask'] = export_config.prefill_mask
188
+ if prefill_masks is not None:
189
+ sample_kwargs['mask'] = prefill_masks[i]
179
190
 
180
191
  if lora is not None:
181
192
  prefill_signature_name += f'_lora_r{lora.get_rank()}'
@@ -17,7 +17,7 @@
17
17
 
18
18
  import copy
19
19
  from dataclasses import dataclass
20
- from typing import Optional, Tuple
20
+ from typing import List, Optional, Tuple
21
21
 
22
22
  from ai_edge_torch.generative.layers import attention
23
23
  from ai_edge_torch.generative.layers import builder
@@ -56,8 +56,10 @@ class ExportConfig:
56
56
  # When False, only decode signatures will produce output.
57
57
  output_logits_on_prefill: bool = False
58
58
  # Attention masks given as inputs to the model.
59
- prefill_mask: Optional[torch.Tensor] = None
60
- decode_mask: Optional[torch.Tensor] = None
59
+ prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
60
+ decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
61
+ # The KV Cache class for K and V buffers in attention.
62
+ kvcache_cls: type = kv_utils.KVCache
61
63
 
62
64
 
63
65
  class DecoderOnlyModel(nn.Module):
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250124"
16
+ __version__ = "0.3.0.dev20250125"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250124
3
+ Version: 0.3.0.dev20250125
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=xloGd_dX0MD8k-quT07WLlEN1zIGVtCKu6xBSvjofrc,706
5
+ ai_edge_torch/version.py,sha256=yuz53SwRvngiQ41D-VX7MPmVGe-Vi-UR3v12E-o3P4I,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -49,6 +49,10 @@ ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0H
49
49
  ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
50
50
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
51
51
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
52
+ ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
53
+ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=I5eA-XfFdHjYwDsLIjn23T2e-IgnSCQ129-5DOU8j44,2532
54
+ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
55
+ ai_edge_torch/generative/examples/deepseek/verify.py,sha256=sDYBhmE_CeZw5iLIQ7rJNGLjhcTyKUQGdg7_QQBh9WM,2398
52
56
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
53
57
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
54
58
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
@@ -97,7 +101,7 @@ ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68
97
101
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
98
102
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
99
103
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
100
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=sB_7-PVri8PxKnFG7c8GsTGyrxGEda-oZwGyyScTL3o,5239
104
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=GtwKAByEk0ENGEWbUmC2mAAPkbLZ3M5xH1HIToyu8QE,5307
101
105
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=sQKQ-k6H9kG2brgwLsktjCMeN2h0POyfMP6iNsPNKWc,16271
102
106
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6W58LxmHHkz2ctgpknQkyoDANZAnE9Byp_svfqLpQf0,34793
103
107
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -134,6 +138,11 @@ ai_edge_torch/generative/layers/model_config.py,sha256=ZVRWEGw1BnLbLCuoR71kWGqQt
134
138
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
135
139
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
136
140
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
141
+ ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
142
+ ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
143
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
144
+ ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
145
+ ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
137
146
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
138
147
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
139
148
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -151,15 +160,15 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
151
160
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
152
161
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
153
162
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
154
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF_xz0UR3kD3x74ELsmJetbQnmv7-9gyQ,12473
163
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=AJs_ARfWUqwuFRwYtQQOLd87CiD4mUDwAhq885cqc4Q,12875
155
164
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
156
165
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
157
166
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
158
167
  ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
159
- ai_edge_torch/generative/utilities/converter.py,sha256=QIYxT-zATMzsD3LG-keRkxpJqDKXkbil4Se1KXthWFg,7726
168
+ ai_edge_torch/generative/utilities/converter.py,sha256=6siSpCvH_cLV-eP40lkF_AqjBpYv68xeMRQ722fKgE0,8065
160
169
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
161
170
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
162
- ai_edge_torch/generative/utilities/model_builder.py,sha256=aXigoFEMLAKk7HQuWJM5ILs3igA4z2VH64ZCzCuBhDE,6671
171
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
163
172
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
164
173
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
165
174
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -213,8 +222,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
213
222
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
214
223
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
215
224
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
216
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
217
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/METADATA,sha256=WKVqBXJtXvMv3JfqtYKcl1GFgKrKtSbZ-tJAol5PPHk,1966
218
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
219
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
220
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/RECORD,,
225
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/METADATA,sha256=BkUH2iAinJYGmBLTMdeYSpihXAHY_mBOkeprZLPaDGk,1966
227
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,