ai-edge-torch-nightly 0.4.0.dev20250227__py3-none-any.whl → 0.4.0.dev20250301__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/phi/phi3.py +2 -5
- ai_edge_torch/generative/examples/phi/phi4.py +165 -0
- ai_edge_torch/generative/examples/phi/verify_phi4.py +69 -0
- ai_edge_torch/generative/layers/experimental/attention.py +0 -8
- ai_edge_torch/generative/layers/experimental/kv_cache.py +45 -31
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/RECORD +13 -10
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py
ADDED
@@ -0,0 +1,80 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Phi-4 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi4
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi4'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi4',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
+
+def main(_):
+  pytorch_model = phi4.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
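For readers who want to drive the conversion from Python rather than through absl flags, the new script's main() boils down to the sketch below; the checkpoint path is a placeholder and the flag defaults above are inlined.

    from ai_edge_torch.generative.examples.phi import phi4
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    # Build the reauthored Phi-4 model from a local checkpoint (path is
    # illustrative) and export a multi-signature TFLite model covering the
    # prefill lengths above, sharing a 1280-entry KV cache.
    model = phi4.build_model('/path/to/phi4', kv_cache_max_len=1280)
    converter.convert_to_tflite(
        model,
        output_path='/tmp/',
        output_name_prefix='phi4',
        prefill_seq_len=(8, 64, 128, 256, 512, 1024),
        quantize=True,
        export_config=ExportConfig(),
    )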
ai_edge_torch/generative/examples/phi/phi3.py
CHANGED
@@ -136,10 +136,7 @@ def _build_phi3_rope(
 
 class Phi3_5Mini(model_builder.DecoderOnlyModel):
   """A Phi-3.5 model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    attn_config = self.config.block_config(0).attn_config
+  pass
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,7 +147,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     is 1024.
 
   Returns:
-    The model config for a Phi-
+    The model config for a Phi-3.5 model.
   """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
ai_edge_torch/generative/examples/phi/phi4.py
ADDED
@@ -0,0 +1,165 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a Phi-4 model up to 4K tokens, not to 128K tokens."""
+
+from functools import partial
+import math
+from typing import Tuple
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.gate_up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+)
+
+# max_position_embeddings / original_max_position_embeddings in Phi-4 config.
+ROPE_SCALE_FACTOR = 32
+
+# RoPE short factor in Phi-4 config. According to the LongRoPE paper and its
+# code in https://github.com/microsoft/LongRoPE, these values were searched
+# with min=1.0, step=0.01 to minimize the errors on a sample dataset.
+ROPE_SHORT_FACTOR = [1.0] * 48
+
+
+def _build_phi4_rope(
+    input_pos: int,
+    n_elem: int,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes Rotary Positional Embeddings for Phi-4 model.
+
+  It's a modified version of attn_utils.build_rope_cache with additional
+  arguments for Phi-4 model. It precomputes Rotary Positional Embedding Sin
+  and Cos values with scaling factors for quick lookup during inference.
+
+  Args:
+    input_pos (torch.Tensor): the given input sequence positions.
+    n_elem (int): Each sequence's dimension.
+    base (int, optional): Rope base value.
+    condense_ratio (int, optional): The ratio by which sequence indices are
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's device.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
+
+  Returns:
+    Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
+  """
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
+  theta = theta / theta_factors
+  seq_idx = input_pos / condense_ratio
+  idx_theta = torch.outer(seq_idx, theta)
+  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
+  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
+  return cos, sin
+
+
+class Phi4Mini(model_builder.DecoderOnlyModel):
+  """A Phi-4 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Phi-4 model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+      Default is 1024.
+
+  Returns:
+    The model config for a Phi-4 model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=24,
+      head_dim=128,
+      num_query_groups=8,
+      rotary_base=10000,
+      rotary_percentage=0.75,
+      qkv_transpose_before_split=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+      intermediate_size=8192,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi4_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+  )
+
+  config = cfg.ModelConfig(
+      vocab_size=200064,
+      num_layers=32,
+      max_seq_len=max_seq_len,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=3072,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+      build_rope=build_rope,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  # Phi-4 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  """Instantiates the model instance and loads the checkpoint if provided."""
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Phi4Mini,
+  )
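To make the LongRoPE-style scaling above concrete, here is a self-contained sketch of the cos/sin table that _build_phi4_rope produces. The value n_elem = 96 is an assumption derived from head_dim=128 and rotary_percentage=0.75 in the config above; the package wires that up internally.

    import math
    import torch

    ROPE_SCALE_FACTOR = 32
    ROPE_SHORT_FACTOR = [1.0] * 48
    max_seq_len = 4096
    base = 10000
    n_elem = 96  # assumed: 0.75 * 128; n_elem // 2 == 48 == len(ROPE_SHORT_FACTOR)

    # ln(32) / ln(4096) = 5/12, so scale = sqrt(1 + 5/12) ~= 1.19.
    scale = math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len))

    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    theta = theta / torch.tensor(ROPE_SHORT_FACTOR)  # per-dim LongRoPE factors
    input_pos = torch.arange(16)                     # positions 0..15
    idx_theta = torch.outer(input_pos / 1, theta)    # condense_ratio = 1
    cos = torch.cos(idx_theta) * scale               # shape (16, 48)
    sin = torch.sin(idx_theta) * scale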
ai_edge_torch/generative/examples/phi/verify_phi4.py
ADDED
@@ -0,0 +1,69 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-4 model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi4
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Instruct: Write an email about the weather Output:",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "microsoft/Phi-4-mini-instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = phi4.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
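One trick in this verifier worth calling out: it reuses the Hugging Face download cache as the checkpoint directory for the reauthored model, so the weights are fetched only once. Isolated, the lookup is just the following (it downloads config.json on first use; the printed path is illustrative):

    import pathlib
    import transformers

    checkpoint = "microsoft/Phi-4-mini-instruct"
    # Resolve the local snapshot directory of the cached Hub checkpoint.
    config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
    )
    snapshot_dir = pathlib.Path(config_file).parent
    print(snapshot_dir)  # .../snapshots/<revision>/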
ai_edge_torch/generative/layers/experimental/attention.py
CHANGED
@@ -52,7 +52,6 @@ class TransformerBlock(nn.Module):
         config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        model_config.batch_size,
         model_config.embedding_dim,
         config.attn_config,
         model_config.enable_hlfb,
@@ -119,7 +118,6 @@ class CausalSelfAttention(nn.Module):
 
   def __init__(
       self,
-      batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
@@ -127,14 +125,12 @@ class CausalSelfAttention(nn.Module):
     """Initialize an instance of CausalSelfAttention.
 
     Args:
-      batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
     self.kv_cache = None
-    self.batch_size = batch_size
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
     ) * config.head_dim
@@ -180,10 +176,6 @@ class CausalSelfAttention(nn.Module):
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert B == self.batch_size, (
-        "batch size of input tensor must match with the batch size specified in"
-        " the model configuration."
-    )
 
     qkv = self.qkv_projection(x)
 
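For reference, the fused QKV width computed by qkv_shape in the unchanged code above, worked through with the Phi-4 numbers from phi4.py (plain arithmetic, not a package call):

    # 24 query heads plus 2 * 8 KV heads, each of head dimension 128.
    num_heads, num_query_groups, head_dim = 24, 8, 128
    qkv_width = (num_heads + 2 * num_query_groups) * head_dim
    assert qkv_width == (24 + 16) * 128 == 5120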
ai_edge_torch/generative/layers/experimental/kv_cache.py
CHANGED
@@ -21,23 +21,19 @@ This is an experimental implementation and is subject to change at any time.
 import dataclasses
 from typing import List, Tuple
 
-from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
-from ai_edge_torch.generative.layers.experimental import types
-from ai_edge_torch.generative.utilities
+from ai_edge_torch.generative.layers.experimental import types
+from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
-BATCH_SIZE = 1
-
 
 @dataclasses.dataclass
 class KVCacheEntryBase:
   """A single cache entry that includes K and V caches.
 
   The caches are built based on the provided config with the shape of
-  (batch_size
+  (batch_size, kv_cache_max, num_query_groups, head_dim).
   """
 
   k_cache: torch.Tensor
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
   @classmethod
   def _from_model_config(
       cls,
-
-
-      k_shape: Tuple,
-      v_shape: Tuple,
+      k_shape: Tuple[int, ...],
+      v_shape: Tuple[int, ...],
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
   ) -> "KVCacheEntryBase":
@@ -66,12 +60,11 @@ class KVCacheEntryBase:
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    shape = (
-    return cls._from_model_config(
-        kv_cache_max, config, shape, shape, dtype, device
-    )
+    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
+    return cls._from_model_config(shape, shape, dtype, device)
 
 
 @dataclasses.dataclass
@@ -93,24 +86,22 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    num_kv_heads = config.num_query_groups
     k_shape = (
-
-
+        batch_size,
+        config.num_query_groups,
         kv_cache_max,
         config.head_dim,
-    )  #
+    )  # b, k, s, h
     v_shape = (
-
-
+        batch_size,
+        config.num_query_groups,
         config.head_dim,
         kv_cache_max,
-    )  #
-    return cls._from_model_config(
-        kv_cache_max, config, k_shape, v_shape, dtype, device
-    )
+    )  # b, k, h, s
+    return cls._from_model_config(k_shape, v_shape, dtype, device)
 
 
 @dataclasses.dataclass
@@ -126,6 +117,7 @@ class KVCacheBase:
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     caches = [
         kv_entry_cls.from_model_config(
@@ -133,6 +125,7 @@ class KVCacheBase:
             config.block_config(idx).attn_config,
             dtype,
             device,
+            batch_size,
         )
         for idx in range(config.num_layers)
     ]
@@ -145,6 +138,7 @@ class KVCacheBase:
       config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     """Build an instance of the class based on model config.
 
@@ -154,12 +148,19 @@ class KVCacheBase:
         Defaults to torch.float32.
       device (torch.device, optional): The device placement of the cache
         tensors. Defaults to None.
+      batch_size (int, optional): The batch size of the cache tensors.
+        Defaults to 1.
 
     Returns:
       KVCacheBase: The created cache object.
     """
+    assert batch_size == 1, "Batch size must be 1 for KV Cache."
     return cls._from_model_config(
-        KVCacheEntryBase,
+        KVCacheEntryBase,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
   def flatten(self) -> List[torch.Tensor]:
@@ -177,9 +178,14 @@ class KVCacheBTNH(KVCacheBase):
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryBTNH,
+        KVCacheEntryBTNH,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
 
@@ -192,9 +198,14 @@ class KVCacheTransposed(KVCacheBase):
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryTransposed,
+        KVCacheEntryTransposed,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
 
@@ -258,7 +269,6 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    use_dus: bool = True,
 ) -> KVCacheEntryBase:
   """Out of place update of Cache buffer.
 
@@ -309,6 +319,10 @@ def _update_kv_impl(
   positions = input_pos.clone()
   k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
   v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
-  k = dynamic_update_slice(
-
+  k = dus_utils.dynamic_update_slice(
+      cache.k_cache, k_slice, [x for x in k_slice_indices]
+  )
+  v = dus_utils.dynamic_update_slice(
+      cache.v_cache, v_slice, [x for x in v_slice_indices]
+  )
   return KVCacheEntryTransposed(k, v)
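In plain tensors, the layout this refactor settles on looks like the sketch below. The numbers are the Phi-4 values from this diff; the slice assignment is a behavioral stand-in for the dus_utils.dynamic_update_slice helper (which updates out of place), not its implementation.

    import torch

    batch_size, kv_cache_max = 1, 1280   # from_model_config asserts batch_size == 1
    num_query_groups, head_dim = 8, 128

    # Default entries use (batch, kv_cache_max, kv heads, head dim); the
    # transposed variant keeps K and V in different layouts:
    k_cache = torch.zeros(batch_size, num_query_groups, kv_cache_max, head_dim)  # b, k, s, h
    v_cache = torch.zeros(batch_size, num_query_groups, head_dim, kv_cache_max)  # b, k, h, s

    # A decode step writes one new K/V slice at position `pos`:
    pos = 7
    k_cache[:, :, pos : pos + 1, :] = torch.randn(batch_size, num_query_groups, 1, head_dim)
    v_cache[:, :, :, pos : pos + 1] = torch.randn(batch_size, num_query_groups, head_dim, 1)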
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -27,6 +27,7 @@ from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.examples.phi import phi4
 from ai_edge_torch.generative.examples.qwen import qwen
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.examples.smollm import smollm
@@ -139,6 +140,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = phi3.Phi3_5Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_phi4(self):
+    config = phi4.get_fake_model_config()
+    pytorch_model = phi4.Phi4Mini(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
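The new test leans on get_fake_model_config to keep conversion fast; per the phi4.py diff above, the downsizing can be sanity-checked like this:

    from ai_edge_torch.generative.examples.phi import phi4

    config = phi4.get_fake_model_config(kv_cache_max_len=128)
    assert config.vocab_size == 128
    assert config.num_layers == 2
    assert config.max_seq_len == 256  # 2 * kv_cache_max_len
    assert config.block_config(0).ff_config.intermediate_size == 128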
ai_edge_torch/version.py
CHANGED
-__version__ = "0.4.0.dev20250227"
+__version__ = "0.4.0.dev20250301"
{ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.4.0.dev20250227
+Version: 0.4.0.dev20250301
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=MENyVQGKk5h6YnKhfVQlzGJnWaGJrL8J86HAtU_LAQM,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -84,11 +84,14 @@ ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0Ye
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=CaI_-Vtd0j9FoWIDd8q5z4CFsGYUhTwEWGvMGaXICuU,2514
+ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=hu_fMYqHU_bxE3DzE-sNj8YSrsFLmErnNRZOODVXZjE,2512
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=g-MvEibJT_iIhkec2VGtFFA_iP54VCq9mY4KxwAYF08,2512
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=ddo52Inl5ub81q460cEyKhnsC3txellRErut-_qtBbM,6949
+ai_edge_torch/generative/examples/phi/phi4.py,sha256=OkMwLGe8l2JEAgOFi19AdbNBl1xp1djZBZo8MJP58ho,5732
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
+ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehdak9-5DDisACs9VlTwr8eFwcjQ_kZxgc,2776
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
@@ -147,8 +150,8 @@ ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBiz
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/attention.py,sha256=
-ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=
+ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
+ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=VN4gn4ylaVOwaTR5EXKv0YTVgpQ850bmjGLCgCCI1ps,9267
 ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
 ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -168,7 +171,7 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3GyBn4k7BWVgyGzrbcL30Su3nxZYLtwkCs,14787
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -230,8 +233,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
+ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/METADATA,sha256=VbeGOSHuc6HIM269rYt6xGOlKC_Pr6_EDGFlCVXa7qg,1966
+ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/RECORD,,
{ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/LICENSE
File without changes
{ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/WHEEL
File without changes
{ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250301.dist-info}/top_level.txt
File without changes