ai-edge-torch-nightly 0.3.0.dev20250124__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  2. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  3. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  4. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  5. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  6. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  7. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  8. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  9. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  10. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  11. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  12. ai_edge_torch/generative/utilities/converter.py +15 -4
  13. ai_edge_torch/generative/utilities/model_builder.py +5 -3
  14. ai_edge_torch/version.py +1 -1
  15. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  16. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +19 -10
  17. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  18. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  19. {ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/deepseek/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py ADDED
@@ -0,0 +1,80 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting DeepSeek R1 distilled models to a TFLite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.deepseek import deepseek
+ from ai_edge_torch.generative.utilities import converter
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+     'checkpoint_path',
+     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/deepseek'),
+     'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _OUTPUT_PATH = flags.DEFINE_string(
+     'output_path',
+     '/tmp/',
+     'The path to export the tflite model.',
+ )
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+     'output_name_prefix',
+     'deepseek',
+     'The prefix of the output tflite model name.',
+ )
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+     'prefill_seq_lens',
+     (8, 64, 128, 256, 512, 1024),
+     'List of the maximum sizes of prefill input tensors.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+     'kv_cache_max_len',
+     1280,
+     'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+     'quantize',
+     True,
+     'Whether the model should be quantized.',
+ )
+ _LORA_RANKS = flags.DEFINE_multi_integer(
+     'lora_ranks',
+     None,
+     'If set, the model will be converted with the provided list of LoRA ranks.',
+ )
+
+
+ def main(_):
+   pytorch_model = deepseek.build_model(
+       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+   )
+   converter.convert_to_tflite(
+       pytorch_model,
+       output_path=_OUTPUT_PATH.value,
+       output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+       prefill_seq_len=_PREFILL_SEQ_LENS.value,
+       quantize=_QUANTIZE.value,
+       lora_ranks=_LORA_RANKS.value,
+       export_config=ExportConfig(),
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
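
For orientation, the same conversion can be driven programmatically. A minimal sketch, assuming a local checkpoint directory (the path below is a placeholder) and using only the functions added in this file:

    from ai_edge_torch.generative.examples.deepseek import deepseek
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    # Placeholder path; point this at a DeepSeek-R1-Distill-Qwen checkpoint.
    model = deepseek.build_model('/path/to/deepseek', kv_cache_max_len=1280)
    converter.convert_to_tflite(
        model,
        output_path='/tmp/',
        output_name_prefix='deepseek',
        prefill_seq_len=[8, 64, 128, 256, 512, 1024],
        quantize=True,
        export_config=ExportConfig(),
    )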
ai_edge_torch/generative/examples/deepseek/deepseek.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building DeepSeek R1 distilled models."""
+
+ import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
+ from torch import nn
+
+ TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
+
+
+ class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
+   """A DeepSeek distilled model based on Qwen."""
+   pass
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a DeepSeek R1 distilled Qwen 1.5B model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+       is 1024.
+
+   Returns:
+     The model config for a DeepSeek R1 distilled model.
+   """
+   attn_config = cfg.AttentionConfig(
+       num_heads=12,
+       head_dim=128,
+       num_query_groups=2,
+       rotary_base=10000,
+       rotary_percentage=1.0,
+       qkv_use_bias=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+       intermediate_size=8960,
+   )
+   norm_config = cfg.NormalizationConfig(
+       type=cfg.NormalizationType.RMS_NORM,
+       epsilon=1e-06,
+   )
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
+   config = cfg.ModelConfig(
+       vocab_size=151936,
+       num_layers=28,
+       max_seq_len=4096,
+       embedding_dim=1536,
+       kv_cache_max_len=kv_cache_max_len,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+       lm_head_share_weight_with_embedding=False,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+   config = get_model_config(**kwargs)
+   config.vocab_size = 128
+   config.num_layers = 2
+   # DeepSeek-R1-Distill-Qwen has only one block config.
+   config.block_config(0).ff_config.intermediate_size = 64
+   return config
+
+
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+       model_class=DeepSeekDistillQwen,
+   )
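
A quick sanity check of the config above (a sketch, not part of the package): the 12 query heads share 2 KV groups, i.e. grouped-query attention with 6 queries per KV head, and the fake config shrinks vocab, layers and FFN for fast unit tests:

    from ai_edge_torch.generative.examples.deepseek import deepseek

    config = deepseek.get_model_config(kv_cache_max_len=1280)
    attn = config.block_config(0).attn_config
    assert attn.num_heads // attn.num_query_groups == 6  # GQA: 6 queries per KV head

    fake = deepseek.get_fake_model_config()
    assert fake.num_layers == 2 and fake.vocab_size == 128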
ai_edge_torch/generative/examples/deepseek/verify.py ADDED
@@ -0,0 +1,70 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
+
+ import logging
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.deepseek import deepseek
+ from ai_edge_torch.generative.utilities import transformers_verifier
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum number of tokens to generate.",
+ )
+
+
+ def main(_):
+   checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+   logging.info("Loading the original model from: %s", checkpoint)
+   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+   reauthored_model = deepseek.build_model(reauthored_checkpoint)
+
+   logging.info("Loading the tokenizer from: %s", checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=transformers_verifier.TransformersModelWrapper(
+           original_model
+       ),
+       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+       tokenizer=verifier.TokenizerWrapper(tokenizer),
+       generate_prompts=_PROMPTS.value,
+       max_new_tokens=_MAX_NEW_TOKENS.value,
+       atol=1e-04,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py CHANGED
@@ -85,6 +85,7 @@ def convert_stable_diffusion_to_tflite(
       clip.TENSOR_NAMES,
   )
   loader.load(clip_model, strict=False)
+  clip_model.eval()
 
   diffusion_model = diffusion.Diffusion(
       diffusion.get_model_config(batch_size=2, device_type=_DEVICE_TYPE.value)
@@ -93,6 +94,7 @@ def convert_stable_diffusion_to_tflite(
       diffusion_ckpt_path, diffusion.TENSOR_NAMES
   )
   diffusion_loader.load(diffusion_model, strict=False)
+  diffusion_model.eval()
 
   decoder_model = decoder.Decoder(
       decoder.get_model_config(device_type=_DEVICE_TYPE.value)
@@ -101,6 +103,7 @@ def convert_stable_diffusion_to_tflite(
       decoder_ckpt_path, decoder.TENSOR_NAMES
   )
   decoder_loader.load(decoder_model, strict=False)
+  decoder_model.eval()
 
   # TODO(yichunk): enable image encoder conversion
   # if encoder_ckpt_path is not None:
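
The added eval() calls matter because export should trace inference behavior. A standalone illustration of why, using a toy module with dropout (not from this package):

    import torch
    from torch import nn

    m = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
    x = torch.ones(1, 4)
    m.train()
    # Train mode: dropout is stochastic, so repeated calls can differ.
    a, b = m(x), m(x)
    m.eval()
    # Eval mode: dropout becomes the identity, so the traced/exported
    # graph is deterministic, which is what conversion should capture.
    c, d = m(x), m(x)
    assert torch.equal(c, d)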
ai_edge_torch/generative/layers/experimental/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
ai_edge_torch/generative/layers/experimental/attention.py ADDED
@@ -0,0 +1,269 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Common building blocks for a GPU-specific Attention layer.
+
+ This is a temporary implementation for the GPU. It is subject to change/removal
+ at any time.
+ """
+
+ from typing import Optional, Tuple, Union
+
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import lora as lora_utils
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+ import torch
+ from torch import nn
+
+
+ class TransformerBlock(nn.Module):
+
+   def __init__(
+       self,
+       config: cfg.TransformerBlockConfig,
+       model_config: cfg.ModelConfig,
+   ) -> None:
+     """Initialize an instance of the TransformerBlock.
+
+     Args:
+       config (cfg.TransformerBlockConfig): the configuration object for this
+         transformer block.
+       model_config (cfg.ModelConfig): the configuration object for the model
+         this transformer block belongs to.
+     """
+     super().__init__()
+     self.pre_atten_norm = builder.build_norm(
+         model_config.embedding_dim,
+         config.pre_attention_norm_config,
+     )
+     self.atten_func = CausalSelfAttention(
+         model_config.batch_size,
+         model_config.embedding_dim,
+         config.attn_config,
+         model_config.enable_hlfb,
+     )
+     self.post_atten_norm = builder.build_norm(
+         model_config.embedding_dim,
+         config.post_attention_norm_config,
+     )
+     self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
+     self.config = config
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+       kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
+       lora: Optional[lora_utils.LoRAEntry] = None,
+   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+     """Forward function of the TransformerBlock.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+       mask (torch.Tensor): the optional mask tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       kv_cache (KVCacheEntryBase): the optional kv cache entry.
+       lora (LoRAEntry): the optional lora entry.
+
+     Returns:
+       output activation from this transformer block, and updated kv cache (if
+       passed in).
+     """
+     kv = None
+     if self.config.parallel_residual:
+       x_norm = self.pre_atten_norm(x)
+       atten_func_out = self.atten_func(
+           x_norm, rope, mask, input_pos, kv_cache, lora
+       )
+       if kv_cache is None:
+         attn_out = atten_func_out
+       else:
+         attn_out, kv = atten_func_out
+       ff_out = self.ff(x_norm)
+       output = x + attn_out + ff_out
+     else:
+       x_norm = self.pre_atten_norm(x)
+       atten_func_out = self.atten_func(
+           x_norm, rope, mask, input_pos, kv_cache, lora
+       )
+       if kv_cache is None:
+         attn_out = atten_func_out
+       else:
+         attn_out, kv = atten_func_out
+       x = x + attn_out
+       x_norm = self.post_atten_norm(x)
+       output = x + self.ff(x_norm)
+
+     return output if kv is None else (output, kv)
+
+
+ class CausalSelfAttention(nn.Module):
+
+   def __init__(
+       self,
+       batch_size: int,
+       dim: int,
+       config: cfg.AttentionConfig,
+       enable_hlfb: bool,
+   ) -> None:
+     """Initialize an instance of CausalSelfAttention.
+
+     Args:
+       batch_size (int): batch size of the input tensor.
+       dim (int): causal attention's input/output dimension.
+       config (cfg.AttentionConfig): attention specific configurations.
+       enable_hlfb (bool): whether hlfb is enabled or not.
+     """
+     super().__init__()
+     self.kv_cache = None
+     self.batch_size = batch_size
+     qkv_shape = (
+         config.num_heads + 2 * config.num_query_groups
+     ) * config.head_dim
+     output_shape = config.num_heads * config.head_dim
+     # Key, query, value projections for all heads.
+     self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
+     self.output_projection = nn.Linear(
+         output_shape, dim, bias=config.output_proj_use_bias
+     )
+     self.query_norm = builder.build_norm(
+         config.head_dim, config.query_norm_config
+     )
+     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
+     self.config = config
+     self.enable_hlfb = enable_hlfb
+     self.sdpa_func = sdpa.scaled_dot_product_attention
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+       kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
+       lora: Optional[lora_utils.LoRAEntry] = None,
+   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+     """Forward function of the CausalSelfAttention layer, supporting MQA,
+     GQA and MHA.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+       mask (torch.Tensor): the optional mask tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       kv_cache (KVCacheEntryBase): the KV cache entry corresponding to this
+         module.
+       lora (LoRAEntry): the optional lora entry.
+
+     Returns:
+       output activation from this self attention layer, and the updated
+       KV Cache Entry (if passed in).
+     """
+     # Batch size, sequence length, embedding dimensionality.
+     B, T, E = x.size()
+     assert B == self.batch_size, (
+         "batch size of input tensor must match with the batch size specified in"
+         " the model configuration."
+     )
+
+     qkv = self.qkv_projection(x)
+
+     # Assemble into a number of query groups to support MHA, MQA and GQA.
+     q_per_kv = self.config.num_heads // self.config.num_query_groups
+     # Each group has >=1 queries, 1 key, and 1 value.
+     if self.config.qkv_transpose_before_split:
+       qkv = qkv.view(B, T, -1, self.config.head_dim)
+       q, k, v = qkv.split(
+           (
+               q_per_kv * self.config.num_query_groups,
+               self.config.num_query_groups,
+               self.config.num_query_groups,
+           ),
+           dim=-2,
+       )
+     else:
+       qkv = qkv.view(B, T, self.config.num_query_groups, -1)
+       q, k, v = qkv.split(
+           (
+               q_per_kv * self.config.head_dim,
+               self.config.head_dim,
+               self.config.head_dim,
+           ),
+           dim=-1,
+       )
+
+     if lora is not None:
+       q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+       k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+       v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
+     q = self.query_norm(q)
+     k = self.key_norm(k)
+
+     q = q.reshape(B, T, -1, self.config.head_dim)
+     k = k.reshape(B, T, -1, self.config.head_dim)
+     v = v.reshape(B, T, -1, self.config.head_dim)
+
+     if rope is not None:
+       # Compute rotary positional embedding for query and key.
+       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+       cos, sin = rope
+       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+
+     # Transpose k/v to specific layout for GPU implementation.
+     b, _, n, h = q.shape
+     g = n // self.config.num_query_groups
+     # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
+     q = q.permute(0, 2, 1, 3).reshape(
+         1, b * self.config.num_query_groups, g * T, h
+     )
+
+     k = k.permute(0, 2, 1, 3).reshape(
+         1, -1, T, self.config.head_dim
+     )  # 1, bk, s, h
+     v = v.permute(0, 2, 3, 1).reshape(
+         1, -1, self.config.head_dim, T
+     )  # 1, bk, h, s
+
+     if kv_cache is not None:
+       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
+       k, v = kv_cache.k_cache, kv_cache.v_cache
+
+     sdpa_out = self.sdpa_func(
+         kv_cache,
+         q,
+         k,
+         v,
+         self.config.head_dim,
+         mask=mask,
+         softcap=self.config.logit_softcap,
+     )  # 1, bk, gt, h
+     sdpa_out = (
+         sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
+     )
+
+     # Compute the output projection.
+     y = self.output_projection(sdpa_out)
+     if lora is not None:
+       y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
+     return y if kv_cache is None else (y, kv_cache)
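
The "btnh -> bnth -> b(kg)th -> 1(bk)(gt)h" comment above is dense; a shape-only sketch of that query transform with made-up sizes (B=1, T=4, 8 query heads, 2 KV groups, head_dim=16):

    import torch

    B, T, n, h = 1, 4, 8, 16
    num_query_groups = 2
    g = n // num_query_groups           # queries per KV group
    q = torch.randn(B, T, n, h)         # btnh layout
    q = q.permute(0, 2, 1, 3)           # bnth
    q = q.reshape(1, B * num_query_groups, g * T, h)
    assert q.shape == (1, 2, 16, 16)    # 1, (b*k), (g*t), h

Folding the group dimension into the sequence axis lets all queries of one KV group hit a single cached K/V slab in one batched matmul, which is the point of the GPU-specific layout.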
ai_edge_torch/generative/layers/experimental/kv_cache.py ADDED
@@ -0,0 +1,314 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utility functions for KV Cache.
+
+ This is an experimental implementation and is subject to change at any time.
+ """
+
+ import dataclasses
+ from typing import List, Tuple
+
+ from ai_edge_torch import hlfb
+ from ai_edge_torch.generative.layers import model_config
+ from ai_edge_torch.generative.layers.experimental import types
+ from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
+ import torch
+ import torch.nn as nn
+ import torch.utils._pytree as pytree
+
+ BATCH_SIZE = 1
+
+
+ @dataclasses.dataclass
+ class KVCacheEntryBase:
+   """A single cache entry that includes K and V caches.
+
+   The caches are built based on the provided config with the shape of
+   (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+   """
+
+   k_cache: torch.Tensor
+   v_cache: torch.Tensor
+
+   @classmethod
+   def _from_model_config(
+       cls,
+       kv_cache_max: int,
+       config: model_config.AttentionConfig,
+       k_shape: Tuple,
+       v_shape: Tuple,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheEntryBase":
+     """Build an instance of the class based on model config."""
+     k = torch.zeros(k_shape, dtype=dtype, device=device)
+     v = torch.zeros(v_shape, dtype=dtype, device=device)
+     obj = cls(k_cache=k, v_cache=v)
+     return obj
+
+   @classmethod
+   def from_model_config(
+       cls,
+       kv_cache_max: int,
+       config: model_config.AttentionConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheEntryBase":
+     """Build an instance of the class based on model config."""
+     shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+     return cls._from_model_config(
+         kv_cache_max, config, shape, shape, dtype, device
+     )
+
+
+ @dataclasses.dataclass
+ class KVCacheEntryBTNH(KVCacheEntryBase):
+   k_type = types.BTNH()
+   v_type = types.BTNH()
+
+
+ @dataclasses.dataclass
+ class KVCacheEntryTransposed(KVCacheEntryBase):
+
+   k_type = types.BNTH()
+   v_type = types.BNHT()
+
+   @classmethod
+   def from_model_config(
+       cls,
+       kv_cache_max: int,
+       config: model_config.AttentionConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheEntryBase":
+     """Build an instance of the class based on model config."""
+     num_kv_heads = config.num_query_groups
+     k_shape = (
+         1,
+         BATCH_SIZE * num_kv_heads,
+         kv_cache_max,
+         config.head_dim,
+     )  # 1, bk, s, h
+     v_shape = (
+         1,
+         BATCH_SIZE * num_kv_heads,
+         config.head_dim,
+         kv_cache_max,
+     )  # 1, bk, h, s
+     return cls._from_model_config(
+         kv_cache_max, config, k_shape, v_shape, dtype, device
+     )
+
+
+ @dataclasses.dataclass
+ class KVCacheBase:
+   """A utility class for holding KV cache entries per layer."""
+
+   caches: Tuple[KVCacheEntryBase, ...]
+
+   @classmethod
+   def _from_model_config(
+       cls,
+       kv_entry_cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheBase":
+     caches = [
+         kv_entry_cls.from_model_config(
+             config.kv_cache_max,
+             config.block_config(idx).attn_config,
+             dtype,
+             device,
+         )
+         for idx in range(config.num_layers)
+     ]
+     obj = cls(caches=tuple(caches))
+     return obj
+
+   @classmethod
+   def from_model_config(
+       cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheBase":
+     """Build an instance of the class based on model config.
+
+     Args:
+       config (ModelConfig): Model config used for building the cache.
+       dtype (torch.dtype, optional): The data type of the cache tensor.
+         Defaults to torch.float32.
+       device (torch.device, optional): The device placement of the cache
+         tensors. Defaults to None.
+
+     Returns:
+       KVCacheBase: The created cache object.
+     """
+     return cls._from_model_config(
+         KVCacheEntryBase, config=config, dtype=dtype, device=device
+     )
+
+   def flatten(self) -> List[torch.Tensor]:
+     """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+     flattened, _ = _flatten_kvc(self)
+     return flattened
+
+
+ @dataclasses.dataclass
+ class KVCacheBTNH(KVCacheBase):
+
+   @classmethod
+   def from_model_config(
+       cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheBTNH":
+     return cls._from_model_config(
+         KVCacheEntryBTNH, config=config, dtype=dtype, device=device
+     )
+
+
+ @dataclasses.dataclass
+ class KVCacheTransposed(KVCacheBase):
+
+   @classmethod
+   def from_model_config(
+       cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheTransposed":
+     return cls._from_model_config(
+         KVCacheEntryTransposed, config=config, dtype=dtype, device=device
+     )
+
+
+ def _flatten_kvc(
+     kvc: KVCacheBase,
+ ) -> Tuple[List[torch.Tensor], List[List[str]]]:
+   flattened = []
+   flat_names = []
+   none_names = []
+   for i, kv_entry in enumerate(kvc.caches):
+     flattened.append(kv_entry.k_cache)
+     flat_names.append(f"k_{i}")
+     flattened.append(kv_entry.v_cache)
+     flat_names.append(f"v_{i}")
+   return flattened, [flat_names, none_names]
+
+
+ def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
+   flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+   return [
+       (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+   ], flat_names
+
+
+ def _unflatten_kvc(
+     values: List[torch.Tensor], context: Tuple[List, List]
+ ) -> KVCacheBase:
+   assert len(values) % 2 == 0, "Found odd number of K and V entries."
+   num_layers = len(values) // 2
+   flat_names = context[0]
+   kv_entries = []
+   for i in range(num_layers):
+     k_cache_idx = flat_names.index(f"k_{i}")
+     v_cache_idx = flat_names.index(f"v_{i}")
+     kv_entries.append(
+         KVCacheEntryBase(
+             k_cache=values[k_cache_idx], v_cache=values[v_cache_idx]
+         )
+     )
+   obj = KVCacheBase(tuple(kv_entries))
+   return obj
+
+
+ pytree.register_pytree_node(
+     KVCacheTransposed,
+     _flatten_kvc,
+     _unflatten_kvc,
+     flatten_with_keys_fn=_flatten_kvc_with_keys,
+     serialized_type_name="",
+ )
+
+ pytree.register_pytree_node(
+     KVCacheBase,
+     _flatten_kvc,
+     _unflatten_kvc,
+     flatten_with_keys_fn=_flatten_kvc_with_keys,
+     serialized_type_name="",
+ )
+
+
+ def update(
+     cache: KVCacheEntryBase,
+     input_pos: torch.Tensor,
+     k_slice: torch.Tensor,
+     v_slice: torch.Tensor,
+     use_dus: bool = True,
+ ) -> KVCacheEntryBase:
+   """Out of place update of Cache buffer.
+
+   Args:
+     cache (KVCacheEntryBase): The original cache buffer.
+     input_pos (torch.Tensor): The update slice positions.
+     k_slice (torch.Tensor): The K slice to be updated in the new cache.
+     v_slice (torch.Tensor): The V slice to be updated in the new cache.
+
+   Returns:
+     KVCacheEntryBase: The updated KVCacheBase entry based on the passed
+       inputs.
+   """
+   update_kv_cache = _update_kv_impl
+   return update_kv_cache(cache, input_pos, k_slice, v_slice)
+
+
+ def _get_slice_indices(
+     positions: torch.Tensor, cache_dim: int, ts_idx: int
+ ) -> torch.Tensor:
+   """Returns the slice indices."""
+   positions = positions.float()[0].reshape(
+       1,
+   )
+
+   zeros = torch.zeros((1,), dtype=torch.float32)
+   indices = []
+   for i in range(cache_dim):
+     if i == ts_idx:
+       indices.append(positions)
+     else:
+       indices.append(zeros)
+   slice_indices = torch.cat(indices, dim=0)
+   slice_indices = slice_indices.int()
+   return slice_indices
+
+
+ def _update_kv_impl(
+     cache: KVCacheEntryTransposed,
+     input_pos: torch.Tensor,
+     k_slice: torch.Tensor,
+     v_slice: torch.Tensor,
+ ) -> KVCacheEntryTransposed:
+   """Update the cache buffer with High Level Function Boundary annotation."""
+   cache_dim = 4
+   k_ts_idx = 2
+   v_ts_idx = 3
+   positions = input_pos.clone()
+   k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
+   v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
+   k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
+   v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
+   return KVCacheEntryTransposed(k, v)
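
To make the transposed layout concrete, a minimal sketch of a single-token decode update with toy sizes (2 KV heads, head_dim 16, cache length 8), assuming the dynamic_update_slice helper behaves like its XLA namesake:

    import torch
    from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils

    entry = kv_utils.KVCacheEntryTransposed(
        k_cache=torch.zeros(1, 2, 8, 16),   # 1, (b*k), s, h
        v_cache=torch.zeros(1, 2, 16, 8),   # 1, (b*k), h, s
    )
    k_slice = torch.ones(1, 2, 1, 16)       # one new token of K
    v_slice = torch.ones(1, 2, 16, 1)       # one new token of V
    # Out-of-place write at sequence position 3; the original buffers
    # are untouched and a new entry is returned.
    updated = kv_utils.update(entry, torch.tensor([3]), k_slice, v_slice)
    assert bool(updated.k_cache[0, 0, 3].any())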
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Implements scaled dot product attention. This is experimental and
+ # GPU-specific code.
+
+ import math
+ from typing import Optional
+
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers.experimental import types
+ from ai_edge_torch.generative.utilities import bmm_4d as bmm_lib
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+ from multipledispatch import dispatch
+ import torch
+ import torch.nn.functional as F
+
+
+ def scaled_dot_product_attention(
+     kv: kv_utils.KVCacheBase,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     head_size: int,
+     mask: Optional[torch.Tensor] = None,
+     scale: Optional[float] = None,
+     softcap: Optional[float] = None,
+ ):
+   if hasattr(kv, "k_type") and hasattr(kv, "v_type"):
+     return _sdpa(
+         kv.k_type,
+         kv.v_type,
+         query=query,
+         key=key,
+         value=value,
+         head_size=head_size,
+         mask=mask,
+         scale=scale,
+         softcap=softcap,
+     )
+   raise ValueError(
+       f"SDPA for K type {type(kv.caches[0].k_type)} and V type"
+       f" {type(kv.caches[0].v_type)} not supported."
+   )
+
+
+ @dispatch(types.BNTH, types.BNHT)
+ def _sdpa(k_type, v_type, *args, **kwargs):
+   query = kwargs["query"]
+   key = kwargs["key"]
+   value = kwargs["value"]
+   head_size = kwargs["head_size"]
+   mask = kwargs.get("mask", None)
+   scale = kwargs.get("scale", None)
+   softcap = kwargs.get("softcap", None)
+
+   if scale is None:
+     scale = 1.0 / math.sqrt(head_size)
+
+   query = query * scale
+
+   assert mask is not None, "Mask should not be None!"
+   t = mask.shape[2]
+
+   logits = bmm_lib.bmm_4d(query, key)
+
+   _, bk, gt, s = logits.shape
+   g = gt // t
+   logits = logits.reshape((bk, g, t, s))
+   if softcap is not None:
+     logits = torch.tanh(logits / softcap)
+     logits = logits * softcap
+
+   padded_logits = logits + mask
+   padded_logits = padded_logits.reshape(1, bk, gt, s)
+   probs = F.softmax(padded_logits, dim=-1).type_as(key)
+
+   encoded = bmm_lib.bmm_4d(probs, value)
+
+   return encoded  # 1, bk, gt, h
+
+
+ @dispatch(object, object)
+ def _sdpa(k_type, v_type, *args, **kwargs):
+   raise ValueError(f"No implementations for k={k_type} and v={v_type}")
ai_edge_torch/generative/layers/experimental/types.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A listing of types describing the K and V tensors in KV caches.
+
+ import enum
+ from enum import Enum, auto
+ from typing import Tuple
+ from torch import nn
+
+
+ @enum.unique
+ class TensorDims(Enum):
+   BATCH = enum.auto()
+   SEQUENCE = enum.auto()
+   NUM_HEADS = enum.auto()
+   HEAD_DIM = enum.auto()
+   MODEL_DIM = enum.auto()  # often num_heads * head_dim
+
+
+ DIM_TO_LETTER = {
+     TensorDims.BATCH: 'B',
+     TensorDims.SEQUENCE: 'T',
+     TensorDims.NUM_HEADS: 'N',
+     TensorDims.HEAD_DIM: 'H',
+     TensorDims.MODEL_DIM: 'D',
+ }
+
+
+ class TensorDimensionMeta(type):
+   """Metaclass to create classes representing an order of tensor dimensions."""
+
+   def __new__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
+     """Creates a new class with the given name and tensor dimension order.
+
+     Args:
+       name: Name of the new class.
+       bases: Base classes for the new class.
+       attrs: Attributes for the new class.
+       dimensions: A tuple of TensorDims defining the order.
+     """
+
+     attrs['dimensions'] = (
+         dimensions  # Store the dimensions as a class attribute.
+     )
+     return super().__new__(cls, name, bases, attrs)
+
+   def __init__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
+     super().__init__(name, bases, attrs)
+
+   def __repr__(cls):
+     return f'{cls.__name__}'
+
+
+ def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
+   """Creates a TensorDimensionMeta class with the specified dimensions.
+
+   Args:
+     dims: A tuple of TensorDims.
+
+   Returns:
+     A new class representing the tensor dimension order.
+   """
+   name = ''.join(DIM_TO_LETTER[d] for d in dims)
+   # Derive from nn.Module for torch tracing compatibility.
+   return TensorDimensionMeta(name, (nn.Module,), {}, dimensions=dims)
+
+
+ BTNH = create_tensor_dimension_order_class((
+     TensorDims.BATCH,
+     TensorDims.SEQUENCE,
+     TensorDims.NUM_HEADS,
+     TensorDims.HEAD_DIM,
+ ))
+ BNTH = create_tensor_dimension_order_class((
+     TensorDims.BATCH,
+     TensorDims.NUM_HEADS,
+     TensorDims.SEQUENCE,
+     TensorDims.HEAD_DIM,
+ ))
+ BNHT = create_tensor_dimension_order_class((
+     TensorDims.BATCH,
+     TensorDims.NUM_HEADS,
+     TensorDims.HEAD_DIM,
+     TensorDims.SEQUENCE,
+ ))
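
What the metaclass yields, shown against the classes defined above (behavior inferred from the code, so treat this as a sketch):

    from ai_edge_torch.generative.layers.experimental import types

    print(types.BNTH)             # "BNTH": the class name is built from DIM_TO_LETTER
    print(types.BNTH.dimensions)  # (BATCH, NUM_HEADS, SEQUENCE, HEAD_DIM)
    tag = types.BNTH()            # instances subclass nn.Module, so they survive torch tracing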
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -17,6 +17,7 @@
 
  import ai_edge_torch
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+ from ai_edge_torch.generative.examples.deepseek import deepseek
  from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.examples.llama import llama
@@ -150,16 +151,15 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-
   def test_smollm2(self):
     config = smollm.get_fake_model_config_v2()
     pytorch_model = smollm.SmolLM2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
-
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
@@ -174,6 +174,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
+   @googletest.skipIf(
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
+   )
+   def test_deepseek(self):
+     config = deepseek.get_fake_model_config()
+     pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
+     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -19,7 +19,6 @@ import os
  from typing import Optional, Union
  from ai_edge_torch._convert import converter as converter_utils
  from ai_edge_torch.generative.layers import lora as lora_utils
- import ai_edge_torch.generative.layers.kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.quantize import quant_recipes
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -151,9 +150,21 @@ def _export_helper(
       else None
   )
 
+   if export_config.prefill_mask is None:
+     prefill_masks = None
+   elif isinstance(export_config.prefill_mask, torch.Tensor):
+     prefill_masks = [export_config.prefill_mask]
+   elif isinstance(export_config.prefill_mask, list):
+     prefill_masks = export_config.prefill_mask
+   else:
+     raise ValueError('Prefill masks unrecognized.')
+
+   if prefill_masks:
+     assert len(prefill_masks) == len(prefill_seq_lens)
+
   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(config)
+   kv = export_config.kvcache_cls.from_model_config(config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
 
@@ -174,8 +185,8 @@ def _export_helper(
         'input_pos': prefill_input_pos,
         'kv_cache': kv,
     }
-     if export_config.prefill_mask is not None:
-       sample_kwargs['mask'] = export_config.prefill_mask
+     if prefill_masks is not None:
+       sample_kwargs['mask'] = prefill_masks[i]
 
     if lora is not None:
       prefill_signature_name += f'_lora_r{lora.get_rank()}'
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -17,7 +17,7 @@
 
  import copy
  from dataclasses import dataclass
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple
 
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
@@ -56,8 +56,10 @@ class ExportConfig:
    # When False, only decode signatures will produce output.
    output_logits_on_prefill: bool = False
    # Attention masks given as inputs to the model.
-   prefill_mask: Optional[torch.Tensor] = None
-   decode_mask: Optional[torch.Tensor] = None
+   prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
+   decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
+   # The KV Cache class for K and V buffers in attention.
+   kvcache_cls: type = kv_utils.KVCache
 
 
 class DecoderOnlyModel(nn.Module):
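
Putting the two ExportConfig additions together, a hedged sketch: one prefill mask per prefill length (the converter change above asserts the counts match) plus the experimental KV cache class; the mask shape here is a placeholder, since real masks depend on the model's attention setup:

    import torch
    from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    prefill_seq_lens = [8, 64]
    # Placeholder shapes; one mask per prefill signature.
    masks = [torch.zeros(1, 1, n, 1280) for n in prefill_seq_lens]
    export_config = ExportConfig(
        prefill_mask=masks,                      # list: len must equal len(prefill_seq_lens)
        kvcache_cls=kv_utils.KVCacheTransposed,  # default is the non-experimental KVCache
    )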
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20250124"
+ __version__ = "0.3.0.dev20250125"
{ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20250124
+ Version: 0.3.0.dev20250125
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250124.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=xloGd_dX0MD8k-quT07WLlEN1zIGVtCKu6xBSvjofrc,706
+ ai_edge_torch/version.py,sha256=yuz53SwRvngiQ41D-VX7MPmVGe-Vi-UR3v12E-o3P4I,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -49,6 +49,10 @@ ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0H
  ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
+ ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
+ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=I5eA-XfFdHjYwDsLIjn23T2e-IgnSCQ129-5DOU8j44,2532
+ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
+ ai_edge_torch/generative/examples/deepseek/verify.py,sha256=sDYBhmE_CeZw5iLIQ7rJNGLjhcTyKUQGdg7_QQBh9WM,2398
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
@@ -97,7 +101,7 @@ ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=sB_7-PVri8PxKnFG7c8GsTGyrxGEda-oZwGyyScTL3o,5239
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=GtwKAByEk0ENGEWbUmC2mAAPkbLZ3M5xH1HIToyu8QE,5307
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=sQKQ-k6H9kG2brgwLsktjCMeN2h0POyfMP6iNsPNKWc,16271
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6W58LxmHHkz2ctgpknQkyoDANZAnE9Byp_svfqLpQf0,34793
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -134,6 +138,11 @@ ai_edge_torch/generative/layers/model_config.py,sha256=ZVRWEGw1BnLbLCuoR71kWGqQt
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
+ ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
+ ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
+ ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
+ ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -151,15 +160,15 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF_xz0UR3kD3x74ELsmJetbQnmv7-9gyQ,12473
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=AJs_ARfWUqwuFRwYtQQOLd87CiD4mUDwAhq885cqc4Q,12875
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
  ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
- ai_edge_torch/generative/utilities/converter.py,sha256=QIYxT-zATMzsD3LG-keRkxpJqDKXkbil4Se1KXthWFg,7726
+ ai_edge_torch/generative/utilities/converter.py,sha256=6siSpCvH_cLV-eP40lkF_AqjBpYv68xeMRQ722fKgE0,8065
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
- ai_edge_torch/generative/utilities/model_builder.py,sha256=aXigoFEMLAKk7HQuWJM5ILs3igA4z2VH64ZCzCuBhDE,6671
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -213,8 +222,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/METADATA,sha256=WKVqBXJtXvMv3JfqtYKcl1GFgKrKtSbZ-tJAol5PPHk,1966
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/METADATA,sha256=BkUH2iAinJYGmBLTMdeYSpihXAHY_mBOkeprZLPaDGk,1966
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,