ai-edge-torch-nightly 0.4.0.dev20250227__py3-none-any.whl → 0.4.0.dev20250301__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py ADDED
@@ -0,0 +1,80 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting a Phi-4 model to a multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi4
+ from ai_edge_torch.generative.utilities import converter
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+     'checkpoint_path',
+     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi4'),
+     'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _OUTPUT_PATH = flags.DEFINE_string(
+     'output_path',
+     '/tmp/',
+     'The path to export the tflite model.',
+ )
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+     'output_name_prefix',
+     'phi4',
+     'The prefix of the output tflite model name.',
+ )
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+     'prefill_seq_lens',
+     (8, 64, 128, 256, 512, 1024),
+     'List of the maximum sizes of prefill input tensors.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+     'kv_cache_max_len',
+     1280,
+     'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+     'quantize',
+     True,
+     'Whether the model should be quantized.',
+ )
+ _LORA_RANKS = flags.DEFINE_multi_integer(
+     'lora_ranks',
+     None,
+     'If set, the model will be converted with the provided list of LoRA ranks.',
+ )
+
+
+ def main(_):
+   pytorch_model = phi4.build_model(
+       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+   )
+   converter.convert_to_tflite(
+       pytorch_model,
+       output_path=_OUTPUT_PATH.value,
+       output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+       prefill_seq_len=_PREFILL_SEQ_LENS.value,
+       quantize=_QUANTIZE.value,
+       lora_ranks=_LORA_RANKS.value,
+       export_config=ExportConfig(),
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
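
The script above is the Phi-4 counterpart of the package's existing phi convert_to_tflite.py. As a rough sketch (the checkpoint path below is illustrative and not shipped with the wheel), running it with its default flags amounts to the following direct API calls, all of which appear in the script itself:

    from ai_edge_torch.generative.examples.phi import phi4
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    # Build the PyTorch Phi-4 model from a local checkpoint directory.
    pytorch_model = phi4.build_model('/path/to/phi4', kv_cache_max_len=1280)
    # Convert to a multi-signature tflite model, with one prefill input size
    # per entry in prefill_seq_len (per the flag descriptions above).
    converter.convert_to_tflite(
        pytorch_model,
        output_path='/tmp/',
        output_name_prefix='phi4',
        prefill_seq_len=(8, 64, 128, 256, 512, 1024),
        quantize=True,
        export_config=ExportConfig(),
    )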
ai_edge_torch/generative/examples/phi/phi3.py CHANGED
@@ -136,10 +136,7 @@ def _build_phi3_rope(
 
  class Phi3_5Mini(model_builder.DecoderOnlyModel):
    """A Phi-3.5 model built from the Edge Generative API layers."""
-
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__(config)
-     attn_config = self.config.block_config(0).attn_config
+   pass
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -150,7 +147,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
        is 1024.
 
    Returns:
-     The model config for a Phi-2 model.
+     The model config for a Phi-3.5 model.
    """
    attn_config = cfg.AttentionConfig(
        num_heads=32,
ai_edge_torch/generative/examples/phi/phi4.py ADDED
@@ -0,0 +1,165 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building a Phi-4 model supporting up to 4K tokens, not 128K."""
+
+ from functools import partial
+ import math
+ from typing import Tuple
+
+ import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.gate_up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+ )
+
+ # max_position_embeddings / original_max_position_embeddings in Phi-4 config.
+ ROPE_SCALE_FACTOR = 32
+
+ # RoPE short factor in the Phi-4 config. According to the LongRoPE paper and
+ # its code in https://github.com/microsoft/LongRoPE, these values were
+ # searched with min=1.0, step=0.01 to minimize the error on a sample dataset.
+ ROPE_SHORT_FACTOR = [1.0] * 48
+
+
+ def _build_phi4_rope(
+     input_pos: torch.Tensor,
+     n_elem: int,
+     base: int,
+     condense_ratio: int,
+     dtype: torch.dtype,
+     device: torch.device,
+     theta_factors: torch.Tensor,
+     scale: float,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+   """Computes Rotary Positional Embeddings for the Phi-4 model.
+
+   It's a modified version of attn_utils.build_rope_cache with additional
+   arguments for the Phi-4 model. It precomputes the Rotary Positional
+   Embedding sin and cos values with scaling factors for quick lookup during
+   inference.
+
+   Args:
+     input_pos (torch.Tensor): the given input sequence positions.
+     n_elem (int): Each sequence's dimension.
+     base (int, optional): Rope base value.
+     condense_ratio (int, optional): The ratio by which sequence indices are
+       condensed.
+     dtype (torch.dtype, optional): Output tensor's data type.
+     device (torch.device, optional): Output tensor's device.
+     theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+       to scale the theta values.
+     scale (float, optional): A float used to scale the rope values.
+
+   Returns:
+     Tuple[torch.Tensor, torch.Tensor]: RoPE's cosine and sine waves.
+   """
+   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
+   theta = theta / theta_factors
+   seq_idx = input_pos / condense_ratio
+   idx_theta = torch.outer(seq_idx, theta)
+   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
+   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
+   return cos, sin
+
+
+ class Phi4Mini(model_builder.DecoderOnlyModel):
+   """A Phi-4 model built from the Edge Generative API layers."""
+   pass
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a Phi-4 model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache.
+       Default is 1024.
+
+   Returns:
+     The model config for a Phi-4 model.
+   """
+   attn_config = cfg.AttentionConfig(
+       num_heads=24,
+       head_dim=128,
+       num_query_groups=8,
+       rotary_base=10000,
+       rotary_percentage=0.75,
+       qkv_transpose_before_split=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.SEQUENTIAL,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+       intermediate_size=8192,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
+
+   max_seq_len = 4096
+   # Create the RoPE callable.
+   build_rope = partial(
+       _build_phi4_rope,
+       condense_ratio=1,
+       dtype=torch.float32,
+       device=torch.device("cpu"),
+       theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+       scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+   )
+
+   config = cfg.ModelConfig(
+       vocab_size=200064,
+       num_layers=32,
+       max_seq_len=max_seq_len,
+       kv_cache_max_len=kv_cache_max_len,
+       embedding_dim=3072,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+       build_rope=build_rope,
+   )
+   return config
+
+
+ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+   config = get_model_config(kv_cache_max_len)
+   config.vocab_size = 128
+   config.num_layers = 2
+   config.max_seq_len = 2 * kv_cache_max_len
+   # Phi-4 has only one block config.
+   config.block_config(0).ff_config.intermediate_size = 128
+   return config
+
+
+ def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+   """Instantiates the model instance and loads the checkpoint if provided."""
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+       model_class=Phi4Mini,
+   )
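
To make the LongRoPE-style scaling concrete: with ROPE_SCALE_FACTOR = 32 and max_seq_len = 4096, the scale passed to the RoPE builder is sqrt(1 + log 32 / log 4096) = sqrt(1 + 5/12) ≈ 1.19. A small self-contained check of the helper above (the 16 positions are arbitrary; n_elem = 96 follows from rotary_percentage 0.75 × head_dim 128, and ROPE_SHORT_FACTOR supplies the 48 per-frequency factors):

    import math
    import torch

    # Reproduces the scale computed in get_model_config above.
    scale = math.sqrt(1 + math.log(32) / math.log(4096))  # ~1.1902

    cos, sin = _build_phi4_rope(
        input_pos=torch.arange(16),              # 16 arbitrary positions
        n_elem=96,                               # 0.75 * head_dim(128) rotary dims
        base=10000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),  # 48 ones
        scale=scale,
    )
    assert cos.shape == sin.shape == (16, 48)    # (positions, n_elem // 2)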
ai_edge_torch/generative/examples/phi/verify_phi4.py ADDED
@@ -0,0 +1,69 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Phi-4 model."""
+
+ import logging
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi4
+ from ai_edge_torch.generative.utilities import transformers_verifier
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "Instruct: Write an email about the weather Output:",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum number of new tokens to generate.",
+ )
+
+
+ def main(_):
+   checkpoint = "microsoft/Phi-4-mini-instruct"
+   logging.info("Loading the original model from: %s", checkpoint)
+   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+   reauthored_model = phi4.build_model(reauthored_checkpoint)
+
+   logging.info("Loading the tokenizer from: %s", checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=transformers_verifier.TransformersModelWrapper(
+           original_model
+       ),
+       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+       tokenizer=verifier.TokenizerWrapper(tokenizer),
+       generate_prompts=_PROMPTS.value,
+       max_new_tokens=_MAX_NEW_TOKENS.value,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
ai_edge_torch/generative/layers/experimental/attention.py CHANGED
@@ -52,7 +52,6 @@ class TransformerBlock(nn.Module):
          config.pre_attention_norm_config,
      )
      self.atten_func = CausalSelfAttention(
-         model_config.batch_size,
          model_config.embedding_dim,
          config.attn_config,
          model_config.enable_hlfb,
@@ -119,7 +118,6 @@ class CausalSelfAttention(nn.Module):
 
    def __init__(
        self,
-       batch_size: int,
        dim: int,
        config: cfg.AttentionConfig,
        enable_hlfb: bool,
@@ -127,14 +125,12 @@ class CausalSelfAttention(nn.Module):
      """Initialize an instance of CausalSelfAttention.
 
      Args:
-       batch_size (int): batch size of the input tensor.
        dim (int): causal attention's input/output dimension.
        config (cfg.AttentionConfig): attention specific configurations.
        enable_hlfb (bool): whether hlfb is enabled or not.
      """
      super().__init__()
      self.kv_cache = None
-     self.batch_size = batch_size
      qkv_shape = (
          config.num_heads + 2 * config.num_query_groups
      ) * config.head_dim
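
For scale, with the Phi-4 attention config added elsewhere in this diff (num_heads=24, num_query_groups=8, head_dim=128), the qkv_shape expression above works out as follows. The width depends only on the attention config, so removing batch_size leaves the layer's parameters untouched:

    # Fused Q/K/V projection width: 24 query heads plus 8 K and 8 V heads,
    # each of dimension 128 (values taken from phi4.py in this diff).
    num_heads, num_query_groups, head_dim = 24, 8, 128
    qkv_shape = (num_heads + 2 * num_query_groups) * head_dim
    assert qkv_shape == (24 + 16) * 128 == 5120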
@@ -180,10 +176,6 @@ class CausalSelfAttention(nn.Module):
      """
      # Batch size, sequence length, embedding dimensionality.
      B, T, E = x.size()
-     assert B == self.batch_size, (
-         "batch size of input tensor must match with the batch size specified in"
-         " the model configuration."
-     )
 
      qkv = self.qkv_projection(x)
 
ai_edge_torch/generative/layers/experimental/kv_cache.py CHANGED
@@ -21,23 +21,19 @@ This is an experimental implementation and is subject to change at any time.
  import dataclasses
  from typing import List, Tuple
 
- from ai_edge_torch import hlfb
  from ai_edge_torch.generative.layers import model_config
- from ai_edge_torch.generative.layers.experimental import types as types
- from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
+ from ai_edge_torch.generative.layers.experimental import types
+ from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
  import torch
- import torch.nn as nn
  import torch.utils._pytree as pytree
 
- BATCH_SIZE = 1
-
 
  @dataclasses.dataclass
  class KVCacheEntryBase:
    """A single cache entry that includes K and V caches.
 
    The caches are built based on the provided config with the shape of
-   (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+   (batch_size, kv_cache_max, num_query_groups, head_dim).
    """
 
    k_cache: torch.Tensor
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
    @classmethod
    def _from_model_config(
        cls,
-       kv_cache_max: int,
-       config: model_config.AttentionConfig,
-       k_shape: Tuple,
-       v_shape: Tuple,
+       k_shape: Tuple[int, ...],
+       v_shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
    ) -> "KVCacheEntryBase":
@@ -66,12 +60,11 @@
        config: model_config.AttentionConfig,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
+       batch_size: int = 1,
    ) -> "KVCacheEntryBase":
      """Build an instance of the class based on model config."""
-     shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
-     return cls._from_model_config(
-         kv_cache_max, config, shape, shape, dtype, device
-     )
+     shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
+     return cls._from_model_config(shape, shape, dtype, device)
 
 
  @dataclasses.dataclass
@@ -93,24 +86,22 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
        config: model_config.AttentionConfig,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
+       batch_size: int = 1,
    ) -> "KVCacheEntryBase":
      """Build an instance of the class based on model config."""
-     num_kv_heads = config.num_query_groups
      k_shape = (
-         1,
-         BATCH_SIZE * num_kv_heads,
+         batch_size,
+         config.num_query_groups,
          kv_cache_max,
         config.head_dim,
-     )  # 1, bk, s, h
+     )  # b, k, s, h
      v_shape = (
-         1,
-         BATCH_SIZE * num_kv_heads,
+         batch_size,
+         config.num_query_groups,
          config.head_dim,
          kv_cache_max,
-     )  # 1, bk, h, s
-     return cls._from_model_config(
-         kv_cache_max, config, k_shape, v_shape, dtype, device
-     )
+     )  # b, k, h, s
+     return cls._from_model_config(k_shape, v_shape, dtype, device)
 
 
  @dataclasses.dataclass
@@ -126,6 +117,7 @@ class KVCacheBase:
        config: model_config.ModelConfig,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
+       batch_size: int = 1,
    ) -> "KVCacheBase":
      caches = [
          kv_entry_cls.from_model_config(
@@ -133,6 +125,7 @@
              config.block_config(idx).attn_config,
              dtype,
              device,
+             batch_size,
          )
          for idx in range(config.num_layers)
      ]
@@ -145,6 +138,7 @@
        config: model_config.ModelConfig,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
+       batch_size: int = 1,
    ) -> "KVCacheBase":
      """Build an instance of the class based on model config.
 
@@ -154,12 +148,19 @@
          Defaults to torch.float32.
        device (torch.device, optional): The device placement of the cache
          tensors. Defaults to None.
+       batch_size (int, optional): The batch size of the cache tensors.
+         Defaults to 1.
 
      Returns:
        KVCacheBase: The created cache object.
      """
+     assert batch_size == 1, "Batch size must be 1 for KV Cache."
      return cls._from_model_config(
-         KVCacheEntryBase, config=config, dtype=dtype, device=device
+         KVCacheEntryBase,
+         config=config,
+         dtype=dtype,
+         device=device,
+         batch_size=batch_size,
      )
 
    def flatten(self) -> List[torch.Tensor]:
@@ -177,9 +178,14 @@ class KVCacheBTNH(KVCacheBase):
        config: model_config.ModelConfig,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
+       batch_size: int = 1,
    ) -> "KVCacheBTNH":
      return cls._from_model_config(
-         KVCacheEntryBTNH, config=config, dtype=dtype, device=device
+         KVCacheEntryBTNH,
+         config=config,
+         dtype=dtype,
+         device=device,
+         batch_size=batch_size,
      )
 
 
@@ -192,9 +198,14 @@ class KVCacheTransposed(KVCacheBase):
        config: model_config.ModelConfig,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
+       batch_size: int = 1,
    ) -> "KVCacheTransposed":
      return cls._from_model_config(
-         KVCacheEntryTransposed, config=config, dtype=dtype, device=device
+         KVCacheEntryTransposed,
+         config=config,
+         dtype=dtype,
+         device=device,
+         batch_size=batch_size,
      )
 
 
@@ -258,7 +269,6 @@ def update(
      input_pos: torch.Tensor,
      k_slice: torch.Tensor,
      v_slice: torch.Tensor,
-     use_dus: bool = True,
  ) -> KVCacheEntryBase:
    """Out of place update of Cache buffer.
 
@@ -309,6 +319,10 @@ def _update_kv_impl(
    positions = input_pos.clone()
    k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
    v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
-   k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
-   v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
+   k = dus_utils.dynamic_update_slice(
+       cache.k_cache, k_slice, [x for x in k_slice_indices]
+   )
+   v = dus_utils.dynamic_update_slice(
+       cache.v_cache, v_slice, [x for x in v_slice_indices]
+   )
    return KVCacheEntryTransposed(k, v)
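
Net effect of the kv_cache.py changes: the module-level BATCH_SIZE constant is gone, and batch_size is threaded through every from_model_config constructor (KVCacheBase still asserts batch_size == 1). A sketch of the resulting transposed-entry shapes, reusing the Phi-4 attention config from this diff as an illustration:

    from ai_edge_torch.generative.layers import model_config as cfg
    from ai_edge_torch.generative.layers.experimental import kv_cache

    attn_config = cfg.AttentionConfig(
        num_heads=24,
        head_dim=128,
        num_query_groups=8,
        rotary_base=10000,
        rotary_percentage=0.75,
        qkv_transpose_before_split=True,
    )
    entry = kv_cache.KVCacheEntryTransposed.from_model_config(
        kv_cache_max=1280,
        config=attn_config,
        batch_size=1,
    )
    assert entry.k_cache.shape == (1, 8, 1280, 128)  # b, k, s, h
    assert entry.v_cache.shape == (1, 8, 128, 1280)  # b, k, h, s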
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -27,6 +27,7 @@ from ai_edge_torch.generative.examples.paligemma import decoder2
  from ai_edge_torch.generative.examples.paligemma import paligemma
  from ai_edge_torch.generative.examples.phi import phi2
  from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.examples.phi import phi4
  from ai_edge_torch.generative.examples.qwen import qwen
  from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
  from ai_edge_torch.generative.examples.smollm import smollm
@@ -139,6 +140,15 @@ class TestModelConversion(googletest.TestCase):
      pytorch_model = phi3.Phi3_5Mini(config).eval()
      self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
+   @googletest.skipIf(
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
+   )
+   def test_phi4(self):
+     config = phi4.get_fake_model_config()
+     pytorch_model = phi4.Phi4Mini(config).eval()
+     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
    @googletest.skipIf(
        ai_edge_torch.config.in_oss,
        reason="tests with custom ops are not supported in oss",
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.4.0.dev20250227"
+ __version__ = "0.4.0.dev20250301"
ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.4.0.dev20250227
+ Version: 0.4.0.dev20250301
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=K2jtDrBNGi74j_uQYVUT6MJ2-aQFKkKy5ZYur9iWdVU,706
+ ai_edge_torch/version.py,sha256=MENyVQGKk5h6YnKhfVQlzGJnWaGJrL8J86HAtU_LAQM,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -84,11 +84,14 @@ ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0Ye
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=CaI_-Vtd0j9FoWIDd8q5z4CFsGYUhTwEWGvMGaXICuU,2514
+ ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=hu_fMYqHU_bxE3DzE-sNj8YSrsFLmErnNRZOODVXZjE,2512
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=g-MvEibJT_iIhkec2VGtFFA_iP54VCq9mY4KxwAYF08,2512
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=7pwHStGEPOuO1DnWiiavioRQXskMqbJMv3ctFNFuBU0,7075
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=ddo52Inl5ub81q460cEyKhnsC3txellRErut-_qtBbM,6949
+ ai_edge_torch/generative/examples/phi/phi4.py,sha256=OkMwLGe8l2JEAgOFi19AdbNBl1xp1djZBZo8MJP58ho,5732
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
+ ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehdak9-5DDisACs9VlTwr8eFwcjQ_kZxgc,2776
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
@@ -147,8 +150,8 @@ ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBiz
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
- ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
- ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
+ ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=VN4gn4ylaVOwaTR5EXKv0YTVgpQ850bmjGLCgCCI1ps,9267
  ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
  ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -168,7 +171,7 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bXJwDxSPgxVKp-_6BsEmMA3TuMUaUNiZoYomNounxco,14416
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3GyBn4k7BWVgyGzrbcL30Su3nxZYLtwkCs,14787
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -230,8 +233,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/METADATA,sha256=cHcz3adq1WwVddazAJ06h7SKITJm70eMpFVjoNa2Jw4,1966
- ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/METADATA,sha256=VbeGOSHuc6HIM269rYt6xGOlKC_Pr6_EDGFlCVXa7qg,1966
+ ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.4.0.dev20250301.dist-info/RECORD,,