ai-edge-torch-nightly 0.3.0.dev20240911__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to a supported public registry. It is provided for informational purposes only.
- ai_edge_torch/generative/examples/smallm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smallm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +19 -15
- {ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smallm/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/smallm/convert_to_tflite.py ADDED
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmalLM model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_smallm_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """Converts SmalLM model to multi-signature tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized. Defaults
+      to True.
+  """
+  pytorch_model = smallm.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  quant_suffix = 'q8' if quantize else 'f32'
+  edge_model.export(
+      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
+  convert_smallm_to_tflite(path)
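
For context, a minimal sketch (not part of this diff) of how the exported multi-signature model could be driven from Python, assuming TensorFlow is installed and the default q8 export path above:

# Hypothetical usage sketch: inspect and fetch the two signatures exported
# by convert_smallm_to_tflite. Path assumes the defaults used above.
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='/tmp/smallm_q8_seq512_ekv1024.tflite'
)
# Lists 'prefill' and 'decode' along with their input/output tensor names,
# including the flattened KV-cache tensors.
print(interpreter.get_signature_list())
prefill_runner = interpreter.get_signature_runner('prefill')
decode_runner = interpreter.get_signature_runner('decode')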

ai_edge_torch/generative/examples/smallm/smallm.py ADDED
@@ -0,0 +1,119 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmalLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmalLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmalLM(tiny_llama.TinyLlama):
+  """A SmalLM model built from the Edge Generative API layers.
+
+  SmalLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmalLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmalLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+      Default is 1024.
+
+  Returns:
+    The model config for a SmalLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmalLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmalLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smallm"
+  )
+  define_and_run(input_checkpoint_path)
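
As a quick aside, the attention numbers in get_model_config are internally consistent; an illustrative check (not part of the diff):

# Illustrative arithmetic for the SmalLM-135M config above.
num_heads, head_dim, num_query_groups = 9, 64, 3
embedding_dim = 576

# 9 heads of width 64 match the model width of 576.
assert num_heads * head_dim == embedding_dim
# Grouped-query attention: every 3 query heads share one KV head.
assert num_heads % num_query_groups == 0
print(num_heads // num_query_groups)  # -> 3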

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -44,7 +44,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class
+class TinyLlama(nn.Module):
   """A TinyLlama model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
@@ -169,7 +169,7 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 
 def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
-  model =
+  model = TinyLlama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
   model.eval()

ai_edge_torch/generative/layers/builder.py CHANGED
@@ -59,9 +59,11 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return
+    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
   elif config.type == cfg.NormalizationType.GROUP_NORM:
-    return
+    return normalization.GroupNorm(
+        config.group_num, dim, config.epsilon, config.enable_hlfb
+    )
   else:
     raise ValueError("Unsupported norm type.")
 
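A sketch of how the completed branches might be exercised, assuming the config fields from model_config.py below (illustrative only, not from the diff):

import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg
import torch

# Build a LayerNorm through the builder, opting into the new HLFB path.
norm = builder.build_norm(
    576,  # feature dimension
    cfg.NormalizationConfig(
        type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
    ),
)
out = norm(torch.randn(1, 10, 576))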

ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -104,6 +104,7 @@ class NormalizationConfig:
   """Normalization parameters."""
 
   type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
   epsilon: float = 1e-5
   zero_centered: bool = False
   # Number of groups used in group normalization.

ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -14,7 +14,10 @@
 # ==============================================================================
 # Common normalization layers.
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
       return output * (1 + self.weight)
     else:
       return output * self.weight
+
+
+class GroupNorm(torch.nn.Module):
+
+  def __init__(
+      self,
+      group_num: int,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+  ):
+    """Initialize the GroupNorm layer.
+
+    Args:
+      group_num (int): Number of groups to separate the channels into.
+      dim (int): Dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-5).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.group_num = group_num
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of GroupNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying GroupNorm.
+    """
+    if self.enable_hlfb:
+      return group_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.group_num,
+          self.eps,
+      )
+    else:
+      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(torch.nn.Module):
+
+  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+    """Initialize the LayerNorm layer.
+
+    Args:
+      dim (int): dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-5).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of LayerNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying LayerNorm.
+    """
+    if self.enable_hlfb:
+      return layer_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.eps,
+      )
+    else:
+      return F.layer_norm(
+          x,
+          x.shape,
+          self.weight.broadcast_to(x.shape),
+          self.bias.broadcast_to(x.shape),
+          self.eps,
+      )
+
+
+def group_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    num_groups: int,
+    eps: float,
+):
+  """Group Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    num_groups (int): Number of groups to separate the channels into.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Group Normalization.
+  """
+  x = torch.permute(x, (0, 2, 3, 1))
+
+  builder = StableHLOCompositeBuilder(
+      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+  )
+  x, w, b = builder.mark_inputs(x, w, b)
+  x = torch.permute(x, (0, 3, 1, 2))
+  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
+  y = torch.permute(y, (0, 2, 3, 1))
+  y = builder.mark_outputs(y)
+
+  y = torch.permute(y, (0, 3, 1, 2))
+  return y
+
+
+def layer_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    eps: float,
+):
+  """Layer Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Layer Normalization.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Layer Normalization.
+  """
+  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+  x, w, b = builder.mark_inputs(x, w, b)
+  y = F.layer_norm(
+      x,
+      x.shape,
+      weight=w.broadcast_to(x.shape),
+      bias=b.broadcast_to(x.shape),
+      eps=eps,
+  )
+  y = builder.mark_outputs(y)
+  return y
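
The permutes in group_norm_with_hlfb exist because the composite boundary is marked on NHWC tensors while F.group_norm itself consumes NCHW. A plain-PyTorch check of that layout round-trip (illustrative only, not from the diff):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)                              # NCHW, as the layer receives it
nhwc = torch.permute(x, (0, 2, 3, 1))                    # NHWC: what mark_inputs sees
y = F.group_norm(torch.permute(nhwc, (0, 3, 1, 2)), 4)   # the op itself runs in NCHW
y_nhwc = torch.permute(y, (0, 2, 3, 1))                  # NHWC: what mark_outputs sees
# Round-tripping the layout leaves the math identical to a direct call.
assert torch.allclose(torch.permute(y_nhwc, (0, 3, 1, 2)), F.group_norm(x, 4))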

ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
         config.attention_batch_size,
         config.dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
         config.query_dim,
         config.cross_dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
ai_edge_torch/generative/test/test_loader.py CHANGED
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.
+    model = tiny_llama.TinyLlama(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initialized.

ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -123,7 +123,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
 
     # prefill
     seq_len = 10

ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
 lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
 lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
 lower_by_torch_xla2(torch.ops.aten.native_group_norm)
-lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
 lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
 lower_by_torch_xla2(torch.ops.aten.ne)
 lower_by_torch_xla2(torch.ops.aten.neg)

ai_edge_torch/odml_torch/lowerings/_layer_norm.py ADDED
@@ -0,0 +1,78 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides lowering for core aten to stablehlo for LayerNorm."""
+
+import math
+from typing import Optional
+from ai_edge_torch.odml_torch.lowerings import registry
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+
+# native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
+# Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+@registry.lower(torch.ops.aten.native_layer_norm)
+def _aten_native_layer_norm(
+    lctx,
+    data: ir.Value,
+    normalized_shape: list[int],
+    weight: Optional[ir.Value],
+    bias: Optional[ir.Value],
+    eps: float,
+):
+  data_type: ir.RankedTensorType = data.type
+  unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
+  dest_shape = [
+      1,
+      unnormalized_count,
+      math.prod(normalized_shape),
+  ]
+  dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
+
+  reshaped_data = stablehlo.reshape(dest_type, data)
+
+  one = utils.splat(1, data_type.element_type, [unnormalized_count])
+  zero = utils.splat(0, data_type.element_type, [unnormalized_count])
+  output, mean, var = stablehlo.batch_norm_training(
+      reshaped_data, one, zero, eps, 1
+  )
+  eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
+  rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
+
+  stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
+      normalized_shape
+  )
+  stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
+  mean = stablehlo.reshape(stats_type, mean)
+  rstd = stablehlo.reshape(stats_type, rstd)
+
+  output = stablehlo.reshape(data_type, output)
+
+  data_rank = len(data_type.shape)
+  normalized_rank = len(normalized_shape)
+  if weight is not None:
+    weight = stablehlo.broadcast_in_dim(
+        data_type, weight, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.multiply(weight, output)
+  if bias is not None:
+    bias = stablehlo.broadcast_in_dim(
+        data_type, bias, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.add(bias, output)
+
+  return output, mean, rstd
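
The lowering leans on stablehlo.batch_norm_training with feature_index=1 after reshaping to [1, unnormalized_count, prod(normalized_shape)], so each "feature" row is one layer-norm group. A NumPy sketch of the same reshape trick (illustrative only, not from the diff):

import numpy as np

eps = 1e-5
x = np.random.randn(2, 3, 8).astype(np.float32)
normalized_shape = (8,)

# Collapse to [1, unnormalized_count, prod(normalized_shape)]: one row per
# normalization group, mirroring the dest_shape computed in the lowering.
rows = x.reshape(1, -1, int(np.prod(normalized_shape)))
mean = rows.mean(axis=2, keepdims=True)
var = rows.var(axis=2, keepdims=True)
out = ((rows - mean) / np.sqrt(var + eps)).reshape(x.shape)

# Matches normalizing each trailing length-8 vector independently.
ref = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)
assert np.allclose(out, ref, atol=1e-6)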

ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.
+Version: 0.3.0.dev20240912
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=Li1VzlXx5ExydpfV93yVAd78cF1L_g3x30-daYdgsLA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,6 +47,9 @@ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=X6WfUCDJDEqyyEAYGq1lmKt
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=BzvUrClFx5HKf6PYzJc7ba2O3AwYUJE485u5GSOiPy4,6851
+ai_edge_torch/generative/examples/smallm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/smallm/convert_to_tflite.py,sha256=aqqxQMBBO_dtGB1iZ1tpF8hbGpdZkx0VIz62ZqfVMCc,3036
+ai_edge_torch/generative/examples/smallm/smallm.py,sha256=j7SDdcX0WvgQWgpaAi7Gi39Jf0-w9D9PftDbugNrN1M,3919
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -71,21 +74,21 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=PbWpfg3AOEZjI1FlnZCxRD-kIKtdkR9AOZ6l-9-TpRA,5664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=HwoEWls-uJ7oHj0HYxJtgZZhgiBR_OQPXlR6l14vm5E,6778
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=ee0KHRakhjLjawP32FY2EntxOkyPvjiEZChLnBn_HPc,12601
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=
+ai_edge_torch/generative/layers/builder.py,sha256=KMwMfZ08r5CXHhcPVZ72nZnIAcsMAIKsv7-QPntlqgI,4418
 ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
 ai_edge_torch/generative/layers/kv_cache.py,sha256=WDu03NQwkDCrrrT9Du_3ZOxlURZz3XDbS1PLzFozhMI,6013
-ai_edge_torch/generative/layers/model_config.py,sha256=
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=03tjidDM1uo_H0jsHNjYEUR5R1FEckc1GIxSoE7ItQQ,5780
+ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -97,8 +100,8 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=FU2rmU03Lp-vZ5wWXXCao1WEw7xbpqebFMANL_O2chA,3713
-ai_edge_torch/generative/test/test_loader.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
+ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=SIv7_sc5qHvbHFN8SbAfY00iXGvH7J6cJLkERU_cd5k,5888
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
 ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -135,11 +138,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Ii1akrKLhRTkZ715JxXBBGKv3jGfXReXMQCYNzSnxmM,10567
+ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -151,8 +155,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/METADATA,sha256=EjeMjRJ5PeW8Azc8hoiJeMP_WaHUDlCend4DFIeQnzc,1859
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/RECORD,,

File without changes

File without changes