ai-edge-torch-nightly 0.6.0.dev20250521__py3-none-any.whl → 0.6.0.dev20250523__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +61 -0
- ai_edge_torch/generative/examples/qwen/qwen3.py +171 -0
- ai_edge_torch/generative/examples/qwen/{verify.py → verify_qwen2.py} +1 -0
- ai_edge_torch/generative/examples/qwen/verify_qwen3.py +59 -0
- ai_edge_torch/generative/examples/qwen/verify_util.py +15 -3
- ai_edge_torch/generative/layers/attention.py +39 -8
- ai_edge_torch/generative/layers/builder.py +5 -10
- ai_edge_torch/generative/layers/feed_forward.py +86 -41
- ai_edge_torch/generative/layers/feed_forward_test.py +15 -10
- ai_edge_torch/generative/layers/model_config.py +33 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +5 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250521.dist-info → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250521.dist-info → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info}/RECORD +17 -14
- {ai_edge_torch_nightly-0.6.0.dev20250521.dist-info → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250521.dist-info → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250521.dist-info → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting Qwen 3.0 models to multi-signature tflite model."""
+
+from absl import app
+from ai_edge_torch.generative.examples.qwen import qwen3
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
+
+flags = converter.define_conversion_flags('qwen')
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '1.7b',
+    ['0.6b', '1.7b', '4b'],
+    'The size of the model to convert.',
+)
+
+_BUILDER = {
+    '0.6b': qwen3.build_0_6b_model,
+    '1.7b': qwen3.build_1_7b_model,
+    '4b': qwen3.build_4b_model,
+}
+
+
+def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+  )
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
+      export_config=export_config.get_from_flags(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
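
For orientation, here is a minimal programmatic sketch of the same conversion path that the new script drives via flags. The checkpoint and output paths are hypothetical placeholders, and it assumes the remaining convert_to_tflite arguments carry usable defaults:

# Sketch only: hypothetical paths; assumes optional conversion args default.
from ai_edge_torch.generative.examples.qwen import qwen3
from ai_edge_torch.generative.utilities import converter

pytorch_model = qwen3.build_0_6b_model(
    '/tmp/qwen3-0.6b',  # placeholder dir containing model.safetensors
    kv_cache_max_len=1024,
)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp',  # placeholder
    output_name_prefix='qwen3_0_6b',
    prefill_seq_len=[128],  # one prefill signature of length 128
)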
ai_edge_torch/generative/examples/qwen/qwen3.py
ADDED
@@ -0,0 +1,171 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Qwen 3.0 models."""
+
+from typing import Callable, Dict
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import loader as loading_utils
+from ai_edge_torch.generative.utilities import model_builder
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_query_norm="model.layers.{}.self_attn.q_norm",
+    attn_key_norm="model.layers.{}.self_attn.k_norm",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head="lm_head",
+)
+
+
+class Qwen3(model_builder.DecoderOnlyModel):
+  """A Qwen3 model built from the Edge Generative API layers."""
+
+  pass
+
+
+def get_4b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 3.0 4B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Qwen 3.0 4B model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
+  )
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      head_dim=128,
+      num_query_groups=8,
+      query_norm_config=norm_config,
+      key_norm_config=norm_config,
+      rotary_base=1000000,
+      rotary_percentage=1.0,
+      qkv_use_bias=False,
+      qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,  # No interleaved qkv projection.
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=9728,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151936,
+      num_layers=36,
+      max_seq_len=40960,
+      embedding_dim=2560,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+  )
+  return config
+
+
+def get_1_7b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 3.0 1.7B model."""
+  config = get_4b_model_config(kv_cache_max_len)
+  # Qwen has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 16
+  block_config.attn_config.head_dim = 128
+  block_config.ff_config.intermediate_size = 6144
+  config.num_layers = 28
+  config.embedding_dim = 2048
+  return config
+
+
+def get_0_6b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 3.0 0.6B model."""
+  config = get_4b_model_config(kv_cache_max_len)
+  # Qwen has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 16
+  block_config.attn_config.head_dim = 128
+  block_config.ff_config.intermediate_size = 3072
+  config.num_layers = 28
+  config.embedding_dim = 1024
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_4b_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # Qwen has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_4b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_4b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Qwen3,
+      custom_loader=custom_loader,
+  )
+
+
+def build_1_7b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_7b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Qwen3,
+      custom_loader=custom_loader,
+  )
+
+
+def build_0_6b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_6b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Qwen3,
+      custom_loader=custom_loader,
+  )
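
As a quick sanity check of the 4B attention config above (standalone arithmetic, not from the package): queries use num_heads while keys/values share num_query_groups heads, and with qkv_fused_interleaved=False the fused projection is laid out as [q | k | v]:

# Projection widths implied by the Qwen3-4B config (illustrative check).
num_heads, num_query_groups, head_dim = 32, 8, 128
q_width = num_heads * head_dim                             # 4096
kv_width = num_query_groups * head_dim                     # 1024 each for K, V
qkv_width = (num_heads + 2 * num_query_groups) * head_dim  # 6144
assert (q_width, kv_width, qkv_width) == (4096, 1024, 6144)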
ai_edge_torch/generative/examples/qwen/verify_qwen3.py
ADDED
@@ -0,0 +1,59 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Qwen 3.0 0.6B, 1.7B, and 4B models."""
+
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen import verify_util
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "0.6b",
+    ["0.6b", "1.7b", "4b"],
+    "The size of the model to verify.",
+)
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+_CHECKPOINT = {
+    "0.6b": "Qwen/Qwen3-0.6B",
+    "1.7b": "Qwen/Qwen3-1.7B",
+    "4b": "Qwen/Qwen3-4B",
+}
+
+
+def main(_):
+  verify_util.verify_qwen(
+      model_size=_MODEL_SIZE.value,
+      model_version="v3",
+      checkpoint_dir=_CHECKPOINT[_MODEL_SIZE.value],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/qwen/verify_util.py
CHANGED
@@ -17,24 +17,36 @@ import logging
 import os
 import pathlib
 
-from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.examples.qwen import qwen, qwen3
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
 
-_BUILDER = {
+_BUILDER_V2 = {
     "0.5b": qwen.build_0_5b_model,
     "1.5b": qwen.build_1_5b_model,
     "3b": qwen.build_3b_model,
 }
 
+_BUILDER_V3 = {
+    "0.6b": qwen3.build_0_6b_model,
+    "1.7b": qwen3.build_1_7b_model,
+    "4b": qwen3.build_4b_model,
+}
+
+_BUILDER = {
+    "v2": _BUILDER_V2,
+    "v3": _BUILDER_V3,
+}
+
 DEFAULT_PROMPTS = ["What is the meaning of life?"]
 
 
 def verify_qwen(
     model_size: str,
+    model_version: str,
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
@@ -64,7 +76,7 @@ def verify_qwen(
   reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
 
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = _BUILDER[model_size](
+  reauthored_model = _BUILDER[model_version][model_size](
      checkpoint_path=reauthored_checkpoint,
      custom_loader=custom_loader,
  )
ai_edge_torch/generative/layers/attention.py
CHANGED
@@ -15,6 +15,7 @@
 
 """Common building blocks for Attention layer."""
 
+import abc
 from typing import Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers import builder
@@ -111,7 +112,42 @@ class TransformerBlock(nn.Module):
     return output if kv is None else (output, kv)
 
 
-class CausalSelfAttention(nn.Module):
+class CausalSelfAttentionBase(nn.Module):
+  """Base class for causal self attention layer."""
+
+  def __init__(
+      self, dim: int, config: cfg.AttentionConfig, enable_hlfb: bool
+  ) -> None:
+    super().__init__()
+    self.dim = dim
+    self.config = config
+    self.enable_hlfb = enable_hlfb
+
+    self.query_norm = builder.build_norm(
+        self.config.head_dim, self.config.query_norm_config
+    )
+    self.key_norm = builder.build_norm(
+        self.config.head_dim, self.config.key_norm_config
+    )
+    self.value_norm = builder.build_norm(
+        self.config.head_dim, self.config.value_norm_config
+    )
+
+  @abc.abstractmethod
+  def forward(
+      self,
+      x: torch.Tensor,
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+      mask: Optional[torch.Tensor] = None,
+      input_pos: Optional[torch.Tensor] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
+    raise NotImplementedError()
+
+
+class CausalSelfAttention(CausalSelfAttentionBase):
+  """Causal self attention layer implementation."""
 
   def __init__(
       self,
@@ -126,7 +162,7 @@ class CausalSelfAttention(nn.Module):
       config (cfg.AttentionConfig): attention specific configurations.
      enable_hlfb (bool): whether hlfb is enabled or not.
    """
-    super().__init__()
+    super().__init__(dim, config, enable_hlfb)
     self.kv_cache = None
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
@@ -137,12 +173,6 @@ class CausalSelfAttention(nn.Module):
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
-    self.query_norm = builder.build_norm(
-        config.head_dim, config.query_norm_config
-    )
-    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
-    self.config = config
-    self.enable_hlfb = enable_hlfb
 
   def forward(
       self,
@@ -204,6 +234,7 @@ class CausalSelfAttention(nn.Module):
 
     q = self.query_norm(q)
     k = self.key_norm(k)
+    v = self.value_norm(v)
 
     q = q.reshape(B, T, -1, self.config.head_dim)
     k = k.reshape(B, T, -1, self.config.head_dim)
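
The new CausalSelfAttentionBase hoists the config bookkeeping and the q/k/v norms out of CausalSelfAttention, so an attention variant only has to implement forward. A hypothetical subclass, sketched to show the contract (not part of the package):

# Hypothetical subclass; illustrates the CausalSelfAttentionBase contract only.
from ai_edge_torch.generative.layers import attention


class IdentityAttention(attention.CausalSelfAttentionBase):
  """Degenerate attention that returns its input unchanged."""

  def forward(self, x, rope=None, mask=None, input_pos=None, kv_cache=None,
              lora=None):
    # A real variant would project q/k/v, apply self.query_norm, self.key_norm
    # and self.value_norm per head, attend under `mask`, and project back.
    return x if kv_cache is None else (x, kv_cache)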
ai_edge_torch/generative/layers/builder.py
CHANGED
@@ -15,9 +15,9 @@
 # Builder class for individual components.
 from typing import Callable
 
+from ai_edge_torch.generative.layers import normalization
 import ai_edge_torch.generative.layers.feed_forward as feed_forward
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.normalization as normalization
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -74,6 +74,8 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         dim,
         eps=config.epsilon,
         zero_centered_gamma=config.zero_centered,
+        with_scale=config.with_scale,
+        scale_shift=config.scale_shift,
         enable_hlfb=config.enable_hlfb,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
@@ -107,20 +109,13 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
   else:
     raise ValueError("Unsupported feedforward type.")
 
-  activation = get_activation(config.activation)
-
   pre_ff_norm = build_norm(dim, config.pre_ff_norm_config)
   post_ff_norm = build_norm(dim, config.post_ff_norm_config)
 
   return ff_module(
       dim=dim,
-      hidden_dim=config.intermediate_size,
-      activation=activation,
-      use_bias=config.use_bias,
-      use_glu=(
-          config.activation.type == cfg.ActivationType.GE_GLU
-          or config.activation.type == cfg.ActivationType.SILU_GLU
-      ),
+      activation=get_activation(config.activation),
+      config=config,
      pre_ff_norm=pre_ff_norm,
      post_ff_norm=post_ff_norm,
  )
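
With the constructors now config-driven, build_ff threads a single FeedForwardConfig through instead of unpacked arguments. A usage sketch with illustrative sizes, assuming the default norm sub-configs build identity norms:

import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=3072,  # illustrative
)
ff = builder.build_ff(dim=1024, config=ff_config)  # -> GatedFeedForward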
ai_edge_torch/generative/layers/feed_forward.py
CHANGED
@@ -14,45 +14,69 @@
 # ==============================================================================
 # Common building blocks for FeedForward layers.
 
-from typing import Callable, Optional
+import abc
+from typing import Callable
 
+import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 from torch import nn
 
 
-class SequentialFeedForward(nn.Module):
+class FeedForwardBase(nn.Module):
+  """Base class for feedforward layer."""
+
+  def __init__(
+      self,
+      dim: int,
+      activation: Callable[[torch.Tensor], torch.Tensor],
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+  ):
+    super().__init__()
+    self.dim = dim
+    self.act = activation
+    self.config = config
+    self.hidden_dim = config.intermediate_size
+    self.use_bias = config.use_bias
+    self.use_glu = (
+        config.activation.type == cfg.ActivationType.GE_GLU
+        or config.activation.type == cfg.ActivationType.SILU_GLU
+    )
+    self.pre_ff_norm = pre_ff_norm
+    self.post_ff_norm = post_ff_norm
+
+  @abc.abstractmethod
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    raise NotImplementedError()
+
+
+class SequentialFeedForward(FeedForwardBase):
   """Vanilla sequential Feedforward with customizable activation."""
 
   def __init__(
       self,
       dim: int,
-      hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
-      use_bias=False,
-      use_glu=False,
-      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
   ):
     """Init function for feedforward layer.
 
     Args:
       dim (int): embedding size.
-      hidden_dim (int): hidden dim size of the feedforward layer.
       activation (Callable): activation function used in this block.
-      use_bias (bool): whether to use bias. Default is False.
-      use_glu (bool): whether to use glu in activation. Default is False.
-      pre_ff_norm (Callable): pre feedforward norm. Default is None.
-      post_ff_norm (Callable): post feedforward norm. Default is None.
+      config (cfg.FeedForwardConfig): feedforward layer configuration.
+      pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+      post_ff_norm (Callable): post feedforward norm. Default is identity.
     """
-    super().__init__()
-    self.act = activation
-    if use_glu:
-      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+    if self.use_glu:
+      self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
     else:
-      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+      self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+    self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)
 
   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -68,7 +92,7 @@ class SequentialFeedForward(nn.Module):
     return self.post_ff_norm(out)
 
 
-class GatedFeedForward(nn.Module):
+class GatedFeedForward(FeedForwardBase):
   """Gated Feedforward with customizable activation.
 
   https://arxiv.org/pdf/2002.05202v1.pdf
@@ -77,34 +101,48 @@ class GatedFeedForward(nn.Module):
   def __init__(
       self,
       dim: int,
-      hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
-      use_bias=False,
-      use_glu=False,
-      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
   ):
     """Init function for feedforward layer.
 
     Args:
       dim (int): embedding size.
-      hidden_dim (int): hidden dim size of the feedforward layer.
       activation (Callable): activation function used in this block.
-      use_bias (bool): whether to use bias. Default is False.
-      use_glu (bool): whether to use glu in activation. Default is False.
-      pre_ff_norm (Callable): pre feedforward norm. Default is None.
-      post_ff_norm (Callable): post feedforward norm. Default is None.
+      pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+      post_ff_norm (Callable): post feedforward norm. Default is identity.
+      config (cfg.FeedForwardConfig): feedforward layer configuration.
     """
-    super().__init__()
-    self.act = activation
-    if use_glu:
-      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+
+    if self.use_glu:
+      assert (
+          self.config.use_separate_gating
+      ), 'use_separate_gating must be True for GE_GLU | SILU_GLU activation.'
+
+    if self.config.use_separate_gating:
+      if self.use_glu:
+        self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
+      else:
+        self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+      self.w3 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
     else:
-      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+      self.w_gating = nn.Parameter(
+          torch.ones((2, dim, self.hidden_dim), dtype=torch.float32),
+          requires_grad=False,
+      )
+      self.gating_bias = (
+          nn.Parameter(
+              torch.zeros((2, self.hidden_dim), dtype=torch.float32),
+              requires_grad=False,
+          )
+          if self.use_bias
+          else torch.zeros((2, self.hidden_dim), dtype=torch.float32)
+      )
+
+    self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)
 
   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -116,5 +154,12 @@ class GatedFeedForward(nn.Module):
       torch.Tensor: output tensor after feedforward.
     """
     x_norm = self.pre_ff_norm(x)
-    out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+    if self.config.use_separate_gating:
+      out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+    else:
+      out = self.w2(
+          self.act(torch.matmul(x_norm, self.w_gating[0]) + self.gating_bias[0])
+          * (torch.matmul(x_norm, self.w_gating[1]) + self.gating_bias[1])
+      )
+
     return self.post_ff_norm(out)
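
When use_separate_gating is False, the two gate projections collapse into one stacked w_gating tensor. A standalone sketch of the correspondence with the separate w1/w3 path (bias-free case; the stacking convention is an assumption for illustration):

import torch
from torch import nn
import torch.nn.functional as F

dim, hidden = 8, 16
w1 = nn.Linear(dim, hidden, bias=False)  # activation branch, like self.w1
w3 = nn.Linear(dim, hidden, bias=False)  # linear gate branch, like self.w3
# nn.Linear stores weight as (out, in); stacking the transposes makes
# x @ w_gating[i] reproduce each Linear's output.
w_gating = torch.stack([w1.weight.T, w3.weight.T])  # shape (2, dim, hidden)

x = torch.randn(2, dim)
separate = F.silu(w1(x)) * w3(x)
fused = F.silu(x @ w_gating[0]) * (x @ w_gating[1])
torch.testing.assert_close(separate, fused)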
ai_edge_torch/generative/layers/feed_forward_test.py
CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 from ai_edge_torch.generative.layers import feed_forward
+from ai_edge_torch.generative.layers import model_config as cfg
 import torch
 import torch.nn.functional as F
 from absl.testing import absltest as googletest
@@ -22,28 +23,32 @@ from absl.testing import absltest as googletest
 class FeedForwardTest(googletest.TestCase):
 
   def test_sequential_feed_forward(self):
+    ff_config = cfg.FeedForwardConfig(
+        type=cfg.FeedForwardType.SEQUENTIAL,
+        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+        intermediate_size=10,
+        use_bias=True,
+    )
     ff = feed_forward.SequentialFeedForward(
         dim=10,
-        hidden_dim=10,
         activation=F.silu,
-        use_bias=True,
-        use_glu=False,
-        pre_ff_norm=torch.nn.Identity(),
-        post_ff_norm=torch.nn.Identity(),
+        config=ff_config,
     )
     x = torch.ones((1, 10))
     out = ff(x)
     self.assertEqual(out.shape, (1, 10))
 
   def test_gated_feed_forward(self):
+    ff_config = cfg.FeedForwardConfig(
+        type=cfg.FeedForwardType.GATED,
+        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+        intermediate_size=10,
+        use_bias=True,
+    )
     ff = feed_forward.GatedFeedForward(
         dim=10,
-        hidden_dim=10,
         activation=F.silu,
-        use_bias=True,
-        use_glu=False,
-        pre_ff_norm=torch.nn.Identity(),
-        post_ff_norm=torch.nn.Identity(),
+        config=ff_config,
     )
     x = torch.ones((1, 10))
     out = ff(x)
ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -69,10 +69,32 @@ class NormalizationConfig:
   enable_hlfb: bool = True
   epsilon: float = 1e-5
   zero_centered: bool = False
+  # Whether to use a scale parameter in the normalization.
+  with_scale: bool = False
+  # The shift to apply to the scale parameter.
+  scale_shift: float = 0.0
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
 
 
+# Experimental feature; may be subject to change.
+class KVCacheUpdateStrategy(enum.Enum):
+  """Different alignment strategies of the KV cache.
+
+  Due to restrictions from different devices, we may need to apply different
+  alignment strategies to the KV cache during Attention layer's cache update.
+
+  Available options:
+    INPLACE: Update the existing cache in place using indexes.
+    PREPEND_LEFT: Append the new kv to the left of the existing cache. When this
+      cache update is applied, the newer kvs will always be prepended at the
+      beginning of the cache.
+  """
+
+  INPLACE = enum.auto()
+  PREPEND_LEFT = enum.auto()
+
+
 @dataclasses.dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -108,6 +130,12 @@ class AttentionConfig:
   key_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
+  # The normalization applied to value projection's output.
+  value_norm_config: NormalizationConfig = dataclasses.field(
+      default_factory=NormalizationConfig
+  )
+  # Whether the KV cache is shared with the previous attention block.
+  kv_shared: bool = False
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
@@ -118,6 +146,8 @@ class AttentionConfig:
   sliding_window_size: Optional[int] = None
   # The default causal mask value used by attention layer.
   causal_mask_value: float = float("-inf")
+  # The update strategy of the KV cache. Defaults to INPLACE.
+  kvcache_update_strategy: KVCacheUpdateStrategy = KVCacheUpdateStrategy.INPLACE
 
 
 @dataclasses.dataclass
@@ -135,6 +165,9 @@ class FeedForwardConfig:
   type: FeedForwardType
   activation: ActivationConfig
   intermediate_size: int
+  # Whether to use two separate gating parameters or a single one in
+  # GatedFeedForward.
+  use_separate_gating: bool = True
   use_bias: bool = False
   # The normalization applied to feed forward's input.
   pre_ff_norm_config: NormalizationConfig = dataclasses.field(
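
A sketch of the new config knobs in use (values are illustrative, and it is assumed the remaining AttentionConfig fields carry defaults):

import ai_edge_torch.generative.layers.model_config as cfg

norm = cfg.NormalizationConfig(
    type=cfg.NormalizationType.RMS_NORM,
    with_scale=True,   # enable the scale parameter
    scale_shift=1.0,   # shift applied to that scale
)
attn = cfg.AttentionConfig(
    num_heads=16,
    head_dim=64,
    num_query_groups=4,
    value_norm_config=norm,  # value projections can now be normalized too
    kvcache_update_strategy=cfg.KVCacheUpdateStrategy.PREPEND_LEFT,
)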
ai_edge_torch/generative/utilities/transformers_verifier.py
CHANGED
@@ -15,8 +15,6 @@
 
 """Utilities for the models predefined in HuggingFace transformers."""
 
-from typing import cast
-
 from ai_edge_torch.generative.utilities import verifier
 import torch
 import transformers
@@ -39,4 +37,8 @@ class TransformersModelWrapper(verifier.ModelWrapper):
       self, inputs: torch.Tensor, max_new_tokens: int
   ) -> torch.IntTensor:
     gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
-    return self.model.generate(inputs=inputs, generation_config=gen_config)
+    # Do not override GenerationConfig with model defaults. Always keep greedy
+    # sampling.
+    return self.model.generate(
+        inputs=inputs, generation_config=gen_config, use_model_defaults=False
+    )
ai_edge_torch/version.py
CHANGED
-__version__ = "0.6.0.dev20250521"
+__version__ = "0.6.0.dev20250523"

{ai_edge_torch_nightly-0.6.0.dev20250521.dist-info → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.6.0.dev20250521
+Version: 0.6.0.dev20250523
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.6.0.dev20250521.dist-info → ai_edge_torch_nightly-0.6.0.dev20250523.dist-info}/RECORD
RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=GtHv2onfhAfdaCEqSWpqO8k8_lxn7A37AJJnPbucqbI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -118,9 +118,12 @@ ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=2MlgQrfRkhE7Dya8MIix
 ai_edge_torch/generative/examples/phi/verify_util.py,sha256=kRREOMSikn_BRbTDkQiXBllPZwmWHa9KUk-kK5lCkbU,2945
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=TnzyARHQgmWeOdYsV9WpRj5vhKGBH0kAbp3tMj8ZCYw,1998
+ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py,sha256=GVV8CVj3rdgt_ZTOlpLSa6AD1pMMpMnZEuowzN2AIGM,2004
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=EcIHVeBcJLc290TiPkPfE7jdG_VXZYKlVGf0XQXzqo8,4554
-ai_edge_torch/generative/examples/qwen/verify.py,sha256=
-ai_edge_torch/generative/examples/qwen/verify_util.py,sha256=
+ai_edge_torch/generative/examples/qwen/qwen3.py,sha256=g6aVHjnlPo4YhLjSdXxONaDcKT3fZOh8cewlvf3cfoQ,5554
+ai_edge_torch/generative/examples/qwen/verify_qwen2.py,sha256=ry-c2QesH-0KnrSQygfjUFs6d4kOFvJz2ts_8mP156I,1659
+ai_edge_torch/generative/examples/qwen/verify_qwen3.py,sha256=hmE0gdyzgcDpEDcWiwOzKQcxt4XeAe9DPRspy_I-lc8,1628
+ai_edge_torch/generative/examples/qwen/verify_util.py,sha256=vPROwLRABTChMGo5yWJkZURXP6TKWgh5FJj1Z3Zs6HU,3153
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=BM-ed7KrmPwzI3MvDs2R7P-kJgE1SK_cNVqIfXhtJjs,2411
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=plOi-3LltxReW_HVxhxwee_rYCQq-gsOwbGZtRsM8N8,4443
@@ -166,18 +169,18 @@ ai_edge_torch/generative/examples/tiny_llama/verify_util.py,sha256=_zYGqP4HO_Stc
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6imHQ-aCXtmPXCh_o,13911
 ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
-ai_edge_torch/generative/layers/builder.py,sha256=
+ai_edge_torch/generative/layers/builder.py,sha256=2bUgkyowDkDznkF8XaHyZs4nowHr1QEHYLM7pMaFmIk,4921
 ai_edge_torch/generative/layers/einsum.py,sha256=EsZSWNVWUs0-1plp4TBnhP4ZhaRDBa2VlDO6hWpUAqU,1288
 ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
-ai_edge_torch/generative/layers/feed_forward.py,sha256=
-ai_edge_torch/generative/layers/feed_forward_test.py,sha256=
+ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
+ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
 ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=0FH3UJPVnEhgBO4eUlNaHuQBDo_OKH17ChG5-Ybj2T4,9895
 ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
 ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
@@ -212,7 +215,7 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSyd
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
-ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=
+ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=l54bmmhj613eB2oCoONIAKEHhf8TQOhC9Gwjp6lxHAE,1659
 ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
 ai_edge_torch/generative/utilities/verifier.py,sha256=ETO2ShU5KXG7MLP8eVOWuzuRLCUtapafYHcZ6TZHIkw,13061
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -264,8 +267,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/RECORD,,
+ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/METADATA,sha256=rVs5qa-WVOxoGTyFSWL9oeK9t6or0QJjsk4Cyr7IYpM,2074
+ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.6.0.dev20250523.dist-info/RECORD,,