ai-edge-torch-nightly 0.6.0.dev20250521__py3-none-any.whl → 0.6.0.dev20250523__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting Qwen 3.0 models to multi-signature tflite model."""
+
+ from absl import app
+ from ai_edge_torch.generative.examples.qwen import qwen3
+ from ai_edge_torch.generative.utilities import converter
+ from ai_edge_torch.generative.utilities import export_config
+ from ai_edge_torch.generative.utilities import loader
+
+ flags = converter.define_conversion_flags('qwen')
+
+ _MODEL_SIZE = flags.DEFINE_enum(
+     'model_size',
+     '1.7b',
+     ['0.6b', '1.7b', '4b'],
+     'The size of the model to convert.',
+ )
+
+ _BUILDER = {
+     '0.6b': qwen3.build_0_6b_model,
+     '1.7b': qwen3.build_1_7b_model,
+     '4b': qwen3.build_4b_model,
+ }
+
+
+ def main(_):
+   checkpoint_path = flags.FLAGS.checkpoint_path
+   pytorch_model = _BUILDER[_MODEL_SIZE.value](
+       checkpoint_path,
+       custom_loader=loader.maybe_get_custom_loader(
+           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+       ),
+       kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+   )
+   converter.convert_to_tflite(
+       pytorch_model,
+       output_path=flags.FLAGS.output_path,
+       output_name_prefix=flags.FLAGS.output_name_prefix,
+       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+       quantize=flags.FLAGS.quantize,
+       lora_ranks=flags.FLAGS.lora_ranks,
+       export_config=export_config.get_from_flags(),
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
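
Note: the script above is driven entirely by absl flags. For orientation, the same flow can be sketched directly in Python; this is an illustration, not part of the package, and the checkpoint path, output values, and cache size below are hypothetical.

    # Hypothetical, flag-free version of main() above.
    from ai_edge_torch.generative.examples.qwen import qwen3
    from ai_edge_torch.generative.utilities import converter

    pytorch_model = qwen3.build_1_7b_model(
        '/tmp/qwen3-1.7b',      # hypothetical checkpoint path
        kv_cache_max_len=1280,  # hypothetical cache size
    )
    converter.convert_to_tflite(
        pytorch_model,
        output_path='/tmp/',              # hypothetical
        output_name_prefix='qwen3_1_7b',  # hypothetical
        prefill_seq_len=[256, 1024],      # one prefill signature per length
    )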
ai_edge_torch/generative/examples/qwen/qwen3.py ADDED
@@ -0,0 +1,171 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building Qwen 3.0 models."""
+
+ from typing import Callable, Dict
+ import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import loader as loading_utils
+ from ai_edge_torch.generative.utilities import model_builder
+ import torch
+ from torch import nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     ff_gate_proj="model.layers.{}.mlp.gate_proj",
+     attn_query_proj="model.layers.{}.self_attn.q_proj",
+     attn_key_proj="model.layers.{}.self_attn.k_proj",
+     attn_value_proj="model.layers.{}.self_attn.v_proj",
+     attn_query_norm="model.layers.{}.self_attn.q_norm",
+     attn_key_norm="model.layers.{}.self_attn.k_norm",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+     lm_head="lm_head",
+ )
+
+
+ class Qwen3(model_builder.DecoderOnlyModel):
+   """A Qwen3 model built from the Edge Generative API layers."""
+
+   pass
+
+
+ def get_4b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a Qwen 3.0 4B model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+       is 1024.
+
+   Returns:
+     The model config for a Qwen 3.0 4B model.
+   """
+   norm_config = cfg.NormalizationConfig(
+       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
+   )
+   attn_config = cfg.AttentionConfig(
+       num_heads=32,
+       head_dim=128,
+       num_query_groups=8,
+       query_norm_config=norm_config,
+       key_norm_config=norm_config,
+       rotary_base=1000000,
+       rotary_percentage=1.0,
+       qkv_use_bias=False,
+       qkv_transpose_before_split=True,
+       qkv_fused_interleaved=False,  # No interleaved qkv projection.
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+       intermediate_size=9728,
+   )
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
+   config = cfg.ModelConfig(
+       vocab_size=151936,
+       num_layers=36,
+       max_seq_len=40960,
+       embedding_dim=2560,
+       kv_cache_max_len=kv_cache_max_len,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+   )
+   return config
+
+
+ def get_1_7b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a Qwen 3.0 1.7B model."""
+   config = get_4b_model_config(kv_cache_max_len)
+   # Qwen has only one block config.
+   block_config = config.block_config(0)
+   block_config.attn_config.num_heads = 16
+   block_config.attn_config.head_dim = 128
+   block_config.ff_config.intermediate_size = 6144
+   config.num_layers = 28
+   config.embedding_dim = 2048
+   return config
+
+
+ def get_0_6b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a Qwen 3.0 0.6B model."""
+   config = get_4b_model_config(kv_cache_max_len)
+   # Qwen has only one block config.
+   block_config = config.block_config(0)
+   block_config.attn_config.num_heads = 16
+   block_config.attn_config.head_dim = 128
+   block_config.ff_config.intermediate_size = 3072
+   config.num_layers = 28
+   config.embedding_dim = 1024
+   return config
+
+
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+   config = get_4b_model_config(**kwargs)
+   config.vocab_size = 128
+   config.num_layers = 2
+   # Qwen has only one block config.
+   config.block_config(0).ff_config.intermediate_size = 64
+   return config
+
+
+ def build_4b_model(
+     checkpoint_path: str,
+     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+     **kwargs
+ ) -> nn.Module:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_4b_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+       model_class=Qwen3,
+       custom_loader=custom_loader,
+   )
+
+
+ def build_1_7b_model(
+     checkpoint_path: str,
+     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+     **kwargs
+ ) -> nn.Module:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_1_7b_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+       model_class=Qwen3,
+       custom_loader=custom_loader,
+   )
+
+
+ def build_0_6b_model(
+     checkpoint_path: str,
+     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+     **kwargs
+ ) -> nn.Module:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_0_6b_model_config(**kwargs),
+       tensor_names=TENSOR_NAMES,
+       model_class=Qwen3,
+       custom_loader=custom_loader,
+   )
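
Note: each size-specific config above overrides only a few fields of the shared 4B config. A quick illustration, using nothing beyond the functions defined in this file:

    from ai_edge_torch.generative.examples.qwen import qwen3

    # The 0.6B config is the 4B config with fewer layers, a smaller embedding,
    # and a smaller feedforward; heads and head_dim come from the overrides.
    config = qwen3.get_0_6b_model_config(kv_cache_max_len=256)
    attn = config.block_config(0).attn_config
    print(config.num_layers, config.embedding_dim)  # 28 1024
    print(attn.num_heads, attn.head_dim)            # 16 128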
ai_edge_torch/generative/examples/qwen/verify.py → ai_edge_torch/generative/examples/qwen/verify_qwen2.py RENAMED
@@ -48,6 +48,7 @@ _CHECKPOINT = {
  def main(_):
    verify_util.verify_qwen(
        model_size=_MODEL_SIZE.value,
+       model_version="v2",
        checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
        max_new_tokens=_MAX_NEW_TOKENS.value,
        prompts=_PROMPTS.value,
ai_edge_torch/generative/examples/qwen/verify_qwen3.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Qwen 3.0 0.6B, 1.7B, and 4B models."""
+
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.qwen import verify_util
+
+
+ _MODEL_SIZE = flags.DEFINE_enum(
+     "model_size",
+     "0.6b",
+     ["0.6b", "1.7b", "4b"],
+     "The size of the model to verify.",
+ )
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
+
+ _CHECKPOINT = {
+     "0.6b": "Qwen/Qwen3-0.6B",
+     "1.7b": "Qwen/Qwen3-1.7B",
+     "4b": "Qwen/Qwen3-4B",
+ }
+
+
+ def main(_):
+   verify_util.verify_qwen(
+       model_size=_MODEL_SIZE.value,
+       model_version="v3",
+       checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
+       max_new_tokens=_MAX_NEW_TOKENS.value,
+       prompts=_PROMPTS.value,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
ai_edge_torch/generative/examples/qwen/verify_util.py CHANGED
@@ -17,24 +17,36 @@ import logging
  import os
  import pathlib

- from ai_edge_torch.generative.examples.qwen import qwen
+ from ai_edge_torch.generative.examples.qwen import qwen, qwen3
  from ai_edge_torch.generative.utilities import loader
  from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers


- _BUILDER = {
+ _BUILDER_V2 = {
      "0.5b": qwen.build_0_5b_model,
      "1.5b": qwen.build_1_5b_model,
      "3b": qwen.build_3b_model,
  }

+ _BUILDER_V3 = {
+     "0.6b": qwen3.build_0_6b_model,
+     "1.7b": qwen3.build_1_7b_model,
+     "4b": qwen3.build_4b_model,
+ }
+
+ _BUILDER = {
+     "v2": _BUILDER_V2,
+     "v3": _BUILDER_V3,
+ }
+
  DEFAULT_PROMPTS = ["What is the meaning of life?"]


  def verify_qwen(
      model_size: str,
+     model_version: str,
      checkpoint_dir: str,
      weight_filename: str = "model.safetensors",
      max_new_tokens: int = 30,
@@ -64,7 +76,7 @@ def verify_qwen(
    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)

    logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-   reauthored_model = _BUILDER[model_size](
+   reauthored_model = _BUILDER[model_version][model_size](
        checkpoint_path=reauthored_checkpoint,
        custom_loader=custom_loader,
    )
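
Note: callers must now select both dimensions of the dispatch table. A sketch of the v3 call path, mirroring verify_qwen3.py above:

    from ai_edge_torch.generative.examples.qwen import verify_util

    # Builder lookup is now _BUILDER[model_version][model_size].
    verify_util.verify_qwen(
        model_size="0.6b",
        model_version="v3",
        checkpoint_dir="Qwen/Qwen3-0.6B",
        max_new_tokens=30,
        prompts=["What is the meaning of life?"],
    )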
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -15,6 +15,7 @@

  """Common building blocks for Attention layer."""

+ import abc
  from typing import Optional, Tuple, Union

  from ai_edge_torch.generative.layers import builder
@@ -111,7 +112,42 @@ class TransformerBlock(nn.Module):
      return output if kv is None else (output, kv)


- class CausalSelfAttention(nn.Module):
+ class CausalSelfAttentionBase(nn.Module):
+   """Base class for causal self attention layer."""
+
+   def __init__(
+       self, dim: int, config: cfg.AttentionConfig, enable_hlfb: bool
+   ) -> None:
+     super().__init__()
+     self.dim = dim
+     self.config = config
+     self.enable_hlfb = enable_hlfb
+
+     self.query_norm = builder.build_norm(
+         self.config.head_dim, self.config.query_norm_config
+     )
+     self.key_norm = builder.build_norm(
+         self.config.head_dim, self.config.key_norm_config
+     )
+     self.value_norm = builder.build_norm(
+         self.config.head_dim, self.config.value_norm_config
+     )
+
+   @abc.abstractmethod
+   def forward(
+       self,
+       x: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+       lora: Optional[lora_utils.LoRAEntry] = None,
+   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
+     raise NotImplementedError()
+
+
+ class CausalSelfAttention(CausalSelfAttentionBase):
+   """Causal self attention layer implementation."""

    def __init__(
        self,
@@ -126,7 +162,7 @@ class CausalSelfAttention(nn.Module):
        config (cfg.AttentionConfig): attention specific configurations.
        enable_hlfb (bool): whether hlfb is enabled or not.
      """
-     super().__init__()
+     super().__init__(dim, config, enable_hlfb)
      self.kv_cache = None
      qkv_shape = (
          config.num_heads + 2 * config.num_query_groups
@@ -137,12 +173,6 @@ class CausalSelfAttention(nn.Module):
      self.output_projection = nn.Linear(
          output_shape, dim, bias=config.output_proj_use_bias
      )
-     self.query_norm = builder.build_norm(
-         config.head_dim, config.query_norm_config
-     )
-     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
-     self.config = config
-     self.enable_hlfb = enable_hlfb

    def forward(
        self,
@@ -204,6 +234,7 @@ class CausalSelfAttention(nn.Module):

      q = self.query_norm(q)
      k = self.key_norm(k)
+     v = self.value_norm(v)

      q = q.reshape(B, T, -1, self.config.head_dim)
      k = k.reshape(B, T, -1, self.config.head_dim)
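
Note: the new base class builds query/key/value norms sized to head_dim, matching Qwen3-style per-head normalization. Below is a standalone sketch of what such an RMS norm computes; it is illustrative only, and the package's real implementation lives in normalization.py:

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
      # Normalize over the last axis (head_dim), then apply a learned scale.
      return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

    head_dim = 128
    q = torch.randn(1, 4, head_dim)  # toy (batch, tokens, head_dim) tensor
    q_normed = rms_norm(q, torch.ones(head_dim))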
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -15,9 +15,9 @@
  # Builder class for individual components.
  from typing import Callable

+ from ai_edge_torch.generative.layers import normalization
  import ai_edge_torch.generative.layers.feed_forward as feed_forward
  import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.layers.normalization as normalization
  import torch
  from torch import nn
  import torch.nn.functional as F
@@ -74,6 +74,8 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
        dim,
        eps=config.epsilon,
        zero_centered_gamma=config.zero_centered,
+       with_scale=config.with_scale,
+       scale_shift=config.scale_shift,
        enable_hlfb=config.enable_hlfb,
    )
  elif config.type == cfg.NormalizationType.LAYER_NORM:
@@ -107,20 +109,13 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
    else:
      raise ValueError("Unsupported feedforward type.")

-   activation = get_activation(config.activation)
-
    pre_ff_norm = build_norm(dim, config.pre_ff_norm_config)
    post_ff_norm = build_norm(dim, config.post_ff_norm_config)

    return ff_module(
        dim=dim,
-       hidden_dim=config.intermediate_size,
-       activation=activation,
-       use_bias=config.use_bias,
-       use_glu=(
-           config.activation.type == cfg.ActivationType.GE_GLU
-           or config.activation.type == cfg.ActivationType.SILU_GLU
-       ),
+       activation=get_activation(config.activation),
+       config=config,
        pre_ff_norm=pre_ff_norm,
        post_ff_norm=post_ff_norm,
    )
ai_edge_torch/generative/layers/feed_forward.py CHANGED
@@ -14,45 +14,69 @@
  # ==============================================================================
  # Common building blocks for FeedForward layers.

- from typing import Callable, Optional
+ import abc
+ from typing import Callable

+ import ai_edge_torch.generative.layers.model_config as cfg
  import torch
  from torch import nn


- class SequentialFeedForward(nn.Module):
+ class FeedForwardBase(nn.Module):
+   """Base class for feedforward layer."""
+
+   def __init__(
+       self,
+       dim: int,
+       activation: Callable[[torch.Tensor], torch.Tensor],
+       config: cfg.FeedForwardConfig,
+       pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+       post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+   ):
+     super().__init__()
+     self.dim = dim
+     self.act = activation
+     self.config = config
+     self.hidden_dim = config.intermediate_size
+     self.use_bias = config.use_bias
+     self.use_glu = (
+         config.activation.type == cfg.ActivationType.GE_GLU
+         or config.activation.type == cfg.ActivationType.SILU_GLU
+     )
+     self.pre_ff_norm = pre_ff_norm
+     self.post_ff_norm = post_ff_norm
+
+   @abc.abstractmethod
+   def forward(self, x: torch.Tensor) -> torch.Tensor:
+     raise NotImplementedError()
+
+
+ class SequentialFeedForward(FeedForwardBase):
    """Vanilla sequential Feedforward with customizable activation."""

    def __init__(
        self,
        dim: int,
-       hidden_dim: int,
        activation: Callable[[torch.Tensor], torch.Tensor],
-       use_bias=False,
-       use_glu=False,
-       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+       config: cfg.FeedForwardConfig,
+       pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+       post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
    ):
      """Init function for feedforward layer.

      Args:
        dim (int): embedding size.
-       hidden_dim (int): hidden dim size of the feedforward layer.
        activation (Callable): activation function used in this block.
-       use_bias (Boolean): whether to use bias. Default is false.
-       use_glu (Boolean): whether to use glu in activation. Default is false.
-       pre_ff_norm (Callable): pre feedforward norm. Default is None.
-       post_ff_norm (Callable): post feedforward norm. Default is None.
+       config (cfg.FeedForwardConfig): feedforward layer configuration.
+       pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+       post_ff_norm (Callable): post feedforward norm. Default is identity.
      """
-     super().__init__()
-     self.act = activation
-     if use_glu:
-       self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+     super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+     if self.use_glu:
+       self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
      else:
-       self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-     self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+       self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+     self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)

    def forward(self, x):
      """Forward pass for Feedforward layer.
@@ -68,7 +92,7 @@ class SequentialFeedForward(nn.Module):
      return self.post_ff_norm(out)


- class GatedFeedForward(nn.Module):
+ class GatedFeedForward(FeedForwardBase):
    """Gated Feedforward with customizable activation.

    https://arxiv.org/pdf/2002.05202v1.pdf
@@ -77,34 +101,48 @@ class GatedFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
-       hidden_dim: int,
        activation: Callable[[torch.Tensor], torch.Tensor],
-       use_bias=False,
-       use_glu=False,
-       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+       config: cfg.FeedForwardConfig,
+       pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+       post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
    ):
      """Init function for feedforward layer.

      Args:
        dim (int): embedding size.
-       hidden_dim (int): hidden dim size of the feedforward layer.
        activation (Callable): activation function used in this block.
-       use_bias (Boolean): whether to use bias. Default is false.
-       use_glu (Boolean): whether to use glu in activation. Default is false.
-       pre_ff_norm (Callable): pre feedforward norm. Default is None.
-       post_ff_norm (Callable): post feedforward norm. Default is None.
+       pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+       post_ff_norm (Callable): post feedforward norm. Default is identity.
+       config (cfg.FeedForwardConfig): feedforward layer configuration.
      """
-     super().__init__()
-     self.act = activation
-     if use_glu:
-       self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+     super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+
+     if self.use_glu:
+       assert (
+           self.config.use_separate_gating
+       ), 'use_separate_gating must be True for GE_GLU | SILU_GLU activation.'
+
+     if self.config.use_separate_gating:
+       if self.use_glu:
+         self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
+       else:
+         self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+       self.w3 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
      else:
-       self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-     self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
-     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-     self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+       self.w_gating = nn.Parameter(
+           torch.ones((2, dim, self.hidden_dim), dtype=torch.float32),
+           requires_grad=False,
+       )
+       self.gating_bias = (
+           nn.Parameter(
+               torch.zeros((2, self.hidden_dim), dtype=torch.float32),
+               requires_grad=False,
+           )
+           if self.use_bias
+           else torch.zeros((2, self.hidden_dim), dtype=torch.float32)
+       )
+
+     self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)

    def forward(self, x):
      """Forward pass for Feedforward layer.
@@ -116,5 +154,12 @@ class GatedFeedForward(nn.Module):
      torch.Tensor: output tensor after feedforward.
    """
    x_norm = self.pre_ff_norm(x)
-   out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+   if self.config.use_separate_gating:
+     out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+   else:
+     out = self.w2(
+         self.act(torch.matmul(x_norm, self.w_gating[0]) + self.gating_bias[0])
+         * (torch.matmul(x_norm, self.w_gating[1]) + self.gating_bias[1])
+     )
+
    return self.post_ff_norm(out)
14
14
  # ==============================================================================
15
15
 
16
16
  from ai_edge_torch.generative.layers import feed_forward
17
+ from ai_edge_torch.generative.layers import model_config as cfg
17
18
  import torch
18
19
  import torch.nn.functional as F
19
20
  from absl.testing import absltest as googletest
@@ -22,28 +23,32 @@ from absl.testing import absltest as googletest
22
23
  class FeedForwardTest(googletest.TestCase):
23
24
 
24
25
  def test_sequential_feed_forward(self):
26
+ ff_config = cfg.FeedForwardConfig(
27
+ type=cfg.FeedForwardType.SEQUENTIAL,
28
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
29
+ intermediate_size=10,
30
+ use_bias=True,
31
+ )
25
32
  ff = feed_forward.SequentialFeedForward(
26
33
  dim=10,
27
- hidden_dim=10,
28
34
  activation=F.silu,
29
- use_bias=True,
30
- use_glu=False,
31
- pre_ff_norm=torch.nn.Identity(),
32
- post_ff_norm=torch.nn.Identity(),
35
+ config=ff_config,
33
36
  )
34
37
  x = torch.ones((1, 10))
35
38
  out = ff(x)
36
39
  self.assertEqual(out.shape, (1, 10))
37
40
 
38
41
  def test_gated_feed_forward(self):
42
+ ff_config = cfg.FeedForwardConfig(
43
+ type=cfg.FeedForwardType.GATED,
44
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
45
+ intermediate_size=10,
46
+ use_bias=True,
47
+ )
39
48
  ff = feed_forward.GatedFeedForward(
40
49
  dim=10,
41
- hidden_dim=10,
42
50
  activation=F.silu,
43
- use_bias=True,
44
- use_glu=False,
45
- pre_ff_norm=torch.nn.Identity(),
46
- post_ff_norm=torch.nn.Identity(),
51
+ config=ff_config,
47
52
  )
48
53
  x = torch.ones((1, 10))
49
54
  out = ff(x)
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -69,10 +69,32 @@ class NormalizationConfig:
    enable_hlfb: bool = True
    epsilon: float = 1e-5
    zero_centered: bool = False
+   # Whether to use a scale parameter in the normalization.
+   with_scale: bool = False
+   # The shift to apply to the scale parameter.
+   scale_shift: float = 0.0
    # Number of groups used in group normalization.
    group_num: Optional[float] = None


+ # Experimental feature and may be subject to change.
+ class KVCacheUpdateStrategy(enum.Enum):
+   """Different alignment strategies of the KV cache.
+
+   Due to restrictions from different devices, we may need to apply different
+   alignment strategies to the KV cache during the Attention layer's cache update.
+
+   Available options:
+     INPLACE: Update the existing cache in place using indexes.
+     PREPEND_LEFT: Prepend the new kv to the left of the existing cache. When
+       this cache update is applied, the newer kvs will always be prepended at
+       the beginning of the cache.
+   """
+
+   INPLACE = enum.auto()
+   PREPEND_LEFT = enum.auto()
+
+
  @dataclasses.dataclass
  class AttentionConfig:
    """Attention model's parameters."""
@@ -108,6 +130,12 @@ class AttentionConfig:
    key_norm_config: NormalizationConfig = dataclasses.field(
        default_factory=NormalizationConfig
    )
+   # The normalization applied to value projection's output.
+   value_norm_config: NormalizationConfig = dataclasses.field(
+       default_factory=NormalizationConfig
+   )
+   # Whether the KV cache is shared with the previous attention block.
+   kv_shared: bool = False
    relative_attention_num_buckets: int = 0
    relative_attention_max_distance: int = 0
    # Softcap on the output logits.
@@ -118,6 +146,8 @@ class AttentionConfig:
    sliding_window_size: Optional[int] = None
    # The default causal mask value used by attention layer.
    causal_mask_value: float = float("-inf")
+   # The update strategy of the KV cache. Defaults to INPLACE.
+   kvcache_update_strategy: KVCacheUpdateStrategy = KVCacheUpdateStrategy.INPLACE


  @dataclasses.dataclass
@@ -135,6 +165,9 @@ class FeedForwardConfig:
    type: FeedForwardType
    activation: ActivationConfig
    intermediate_size: int
+   # Whether to use two separate gating parameters or a single one in
+   # GatedFeedForward.
+   use_separate_gating: bool = True
    use_bias: bool = False
    # The normalization applied to feed forward's input.
    pre_ff_norm_config: NormalizationConfig = dataclasses.field(
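
Note: a config exercising the new fields might look like the sketch below. All values are illustrative and do not correspond to any shipped model.

    import ai_edge_torch.generative.layers.model_config as cfg

    # Fused gating path in GatedFeedForward (see feed_forward.py above).
    ff_config = cfg.FeedForwardConfig(
        type=cfg.FeedForwardType.GATED,
        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
        intermediate_size=4096,     # illustrative
        use_separate_gating=False,  # single (2, dim, hidden) w_gating tensor
    )
    attn_config = cfg.AttentionConfig(
        num_heads=16,  # illustrative
        head_dim=128,
        num_query_groups=8,
        kvcache_update_strategy=cfg.KVCacheUpdateStrategy.PREPEND_LEFT,
    )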
ai_edge_torch/generative/utilities/transformers_verifier.py CHANGED
@@ -15,8 +15,6 @@

  """Utilities for the models predefined in HuggingFace transformers."""

- from typing import cast
-
  from ai_edge_torch.generative.utilities import verifier
  import torch
  import transformers
@@ -39,4 +37,8 @@ class TransformersModelWrapper(verifier.ModelWrapper):
      self, inputs: torch.Tensor, max_new_tokens: int
  ) -> torch.IntTensor:
    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
-   return self.model.generate(inputs=inputs, generation_config=gen_config)
+   # Do not override GenerationConfig with model defaults. Always keep greedy
+   # sampling.
+   return self.model.generate(
+       inputs=inputs, generation_config=gen_config, use_model_defaults=False
+   )
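
Note: recent transformers releases merge the checkpoint's own generation_config.json into a user-supplied GenerationConfig unless use_model_defaults=False is passed, so a checkpoint that ships sampling defaults (temperature, top_p) could otherwise replace the greedy decoding that verification depends on.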
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.6.0.dev20250521"
+ __version__ = "0.6.0.dev20250523"
ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/METADATA → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.6.0.dev20250521
+ Version: 0.6.0.dev20250523
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/RECORD → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
- ai_edge_torch/version.py,sha256=lmyCstaeVZjTAbBP4s9Z02tpX00ynyLPsymBY2tCe4A,706
+ ai_edge_torch/version.py,sha256=GtHv2onfhAfdaCEqSWpqO8k8_lxn7A37AJJnPbucqbI,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -118,9 +118,12 @@ ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=2MlgQrfRkhE7Dya8MIix
  ai_edge_torch/generative/examples/phi/verify_util.py,sha256=kRREOMSikn_BRbTDkQiXBllPZwmWHa9KUk-kK5lCkbU,2945
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=TnzyARHQgmWeOdYsV9WpRj5vhKGBH0kAbp3tMj8ZCYw,1998
+ ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py,sha256=GVV8CVj3rdgt_ZTOlpLSa6AD1pMMpMnZEuowzN2AIGM,2004
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=EcIHVeBcJLc290TiPkPfE7jdG_VXZYKlVGf0XQXzqo8,4554
- ai_edge_torch/generative/examples/qwen/verify.py,sha256=mP1SIAX2B1vFO02vRkAZC0UCyvBBxeWxK_456gG5a1s,1633
- ai_edge_torch/generative/examples/qwen/verify_util.py,sha256=jEmqYnOkOcQhOmHJrHsX0vdLq7JSahROvEBrG6n7tqg,2919
+ ai_edge_torch/generative/examples/qwen/qwen3.py,sha256=g6aVHjnlPo4YhLjSdXxONaDcKT3fZOh8cewlvf3cfoQ,5554
+ ai_edge_torch/generative/examples/qwen/verify_qwen2.py,sha256=ry-c2QesH-0KnrSQygfjUFs6d4kOFvJz2ts_8mP156I,1659
+ ai_edge_torch/generative/examples/qwen/verify_qwen3.py,sha256=hmE0gdyzgcDpEDcWiwOzKQcxt4XeAe9DPRspy_I-lc8,1628
+ ai_edge_torch/generative/examples/qwen/verify_util.py,sha256=vPROwLRABTChMGo5yWJkZURXP6TKWgh5FJj1Z3Zs6HU,3153
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
  ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=BM-ed7KrmPwzI3MvDs2R7P-kJgE1SK_cNVqIfXhtJjs,2411
  ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=plOi-3LltxReW_HVxhxwee_rYCQq-gsOwbGZtRsM8N8,4443
@@ -166,18 +169,18 @@ ai_edge_torch/generative/examples/tiny_llama/verify_util.py,sha256=_zYGqP4HO_Stc
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWnQ8npEAfgcjMIkEY,12964
+ ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6imHQ-aCXtmPXCh_o,13911
  ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
  ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
- ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
+ ai_edge_torch/generative/layers/builder.py,sha256=2bUgkyowDkDznkF8XaHyZs4nowHr1QEHYLM7pMaFmIk,4921
  ai_edge_torch/generative/layers/einsum.py,sha256=EsZSWNVWUs0-1plp4TBnhP4ZhaRDBa2VlDO6hWpUAqU,1288
  ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
- ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
- ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKrDt249G5Mz-8VKWW7_WHx0u4,1655
+ ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
+ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
  ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
- ai_edge_torch/generative/layers/model_config.py,sha256=H1MpjP1Ij1r4DEcE4cQ_6A8h0QvUjCkuGATXMkIMIWg,8570
+ ai_edge_torch/generative/layers/model_config.py,sha256=0FH3UJPVnEhgBO4eUlNaHuQBDo_OKH17ChG5-Ybj2T4,9895
  ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
  ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
@@ -212,7 +215,7 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSyd
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
- ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
+ ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=l54bmmhj613eB2oCoONIAKEHhf8TQOhC9Gwjp6lxHAE,1659
  ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
  ai_edge_torch/generative/utilities/verifier.py,sha256=ETO2ShU5KXG7MLP8eVOWuzuRLCUtapafYHcZ6TZHIkw,13061
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -264,8 +267,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/METADATA,sha256=_UC8q7Xe3xMUCwKKbF4CJ5hewK9PLIJ26ksKCAeWjik,2074
- ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/METADATA,sha256=rVs5qa-WVOxoGTyFSWL9oeK9t6or0QJjsk4Cyr7IYpM,2074
+ ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/RECORD,,