ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/feed_forward.py

@@ -30,18 +30,27 @@ class SequentialFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      use_glu=False,
       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.

-    Args:
-
-
+    Args:
+      dim (int): embedding size.
+      hidden_dim (int): hidden dim size of the feedforward layer.
+      activation (Callable): activation function used in this block.
+      use_bias (Boolean): whether to use bias. Default is false.
+      use_glu (Boolean): whether to use glu in activation. Default is false.
+      pre_ff_norm (Callable): pre feedforward norm. Default is None.
+      post_ff_norm (Callable): post feedforward norm. Default is None.
     """
     super().__init__()
     self.act = activation
-
+    if use_glu:
+      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    else:
+      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
     self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x

@@ -72,18 +81,27 @@ class GatedFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      use_glu=False,
       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.

-    Args:
-
-
+    Args:
+      dim (int): embedding size.
+      hidden_dim (int): hidden dim size of the feedforward layer.
+      activation (Callable): activation function used in this block.
+      use_bias (Boolean): whether to use bias. Default is false.
+      use_glu (Boolean): whether to use glu in activation. Default is false.
+      pre_ff_norm (Callable): pre feedforward norm. Default is None.
+      post_ff_norm (Callable): post feedforward norm. Default is None.
     """
     super().__init__()
     self.act = activation
-
+    if use_glu:
+      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    else:
+      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
     self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
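Both feed-forward variants now accept a `use_glu` flag that doubles the `w1` projection. The forward pass is not part of this diff, so the snippet below is only a minimal sketch of the usual GLU pattern that a doubled projection implies (split into a value half and a gate half, activate the gate, multiply); the class name and the gate ordering are assumptions, not the package's code. Which half acts as the gate corresponds to the new `gate_is_front` flag in `ActivationConfig` (see the model_config.py hunks below).

```python
import torch
from torch import nn
import torch.nn.functional as F


class GluFeedForwardSketch(nn.Module):
  """Hypothetical illustration of the GLU pattern behind `use_glu`."""

  def __init__(self, dim: int, hidden_dim: int, use_bias: bool = False):
    super().__init__()
    # With GLU enabled, w1 projects to 2 * hidden_dim: one half is the value,
    # the other half is the gate.
    self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
    self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    value, gate = self.w1(x).chunk(2, dim=-1)  # split the doubled projection
    return self.w2(value * F.silu(gate))       # gate the value elementwise


x = torch.randn(1, 4, 8)
print(GluFeedForwardSketch(dim=8, hidden_dim=16)(x).shape)  # torch.Size([1, 4, 8])
```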
ai_edge_torch/generative/layers/kv_cache.py

@@ -12,72 +12,184 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# `nn.Module` which implements a KV cache.

-
+"""Utility functions for externalized KV Cache."""
+
+import dataclasses
+from typing import List, Tuple
+
+from ai_edge_torch import hlfb
+from ai_edge_torch.generative.layers import model_config
 import torch
-
+import torch.utils._pytree as pytree

+BATCH_SIZE = 1

-class KVCache(nn.Module):

-
-
-
-  """Initializes the KVCache layer.
+@dataclasses.dataclass
+class KVCacheEntry:
+  """A single cache entry that includes K and V caches.

-
-
-
-      n_heads (int): number of kv heads.
-      head_dim (int): the head dimension size.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
-    self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
-    self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
-    self.enable_hlfb = enable_hlfb
-    self.kv_cache_max = kv_cache_max
+  The chaches are built based on the provided config with the shape of
+  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  """

-
-
+  k_cache: torch.Tensor
+  v_cache: torch.Tensor

-
-
-
-
+  @classmethod
+  def from_model_config(
+      cls,
+      kv_cache_max: int,
+      config: model_config.AttentionConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCacheEntry":
+    """Build an instance of the class based on model config."""
+    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+    k = torch.zeros(shape, dtype=dtype, device=device)
+    v = torch.zeros(shape, dtype=dtype, device=device)
+    obj = cls(k_cache=k, v_cache=v)
+    return obj

-    Returns:
-      The updated key and value tensor.
-    """
-    if self.enable_hlfb:
-      return self.update_cache_with_hlfb(input_pos, k_val, v_val)

-
-
-
-    return torch.clone(updated_k), torch.clone(updated_v)
+@dataclasses.dataclass
+class KVCache:
+  """A utility class for holding KV cache entries per layer."""

-
-
+  caches: Tuple[KVCacheEntry, ...]
+
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCache":
+    """Build an instance of the class based on model config.

     Args:
-
-
-
+      config (ModelConfig): Model config used for building the cache.
+      dtype (torch.dtype, optional): The data type of the cache tensor.
+        Defaults to torch.float32.
+      device (torch.device, optional): The device placement of the cache
+        tensors. Defaults to None.

     Returns:
-
+      KVCache: The created cache object.
     """
+    caches = [
+        KVCacheEntry.from_model_config(
+            config.kv_cache_max,
+            config.block_config(idx).attn_config,
+            dtype,
+            device,
+        )
+        for idx in range(config.num_layers)
+    ]
+    obj = cls(caches=tuple(caches))
+    return obj

-
-
-    )
-
-
+  def flatten(self) -> List[torch.Tensor]:
+    """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+    flattened, _ = _flatten_kvc(self)
+    return flattened
+
+
+def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+  flattened = []
+  flat_names = []
+  none_names = []
+  for i, kv_entry in enumerate(kvc.caches):
+    flattened.append(kv_entry.k_cache)
+    flat_names.append(f"k_{i}")
+    flattened.append(kv_entry.v_cache)
+    flat_names.append(f"v_{i}")
+  return flattened, [flat_names, none_names]
+
+
+def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+  flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+  return [
+      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+  ], flat_names
+
+
+def _unflatten_kvc(
+    values: List[torch.Tensor], context: Tuple[List, List]
+) -> KVCache:
+  assert len(values) % 2 == 0, "Found odd number of K and V entries."
+  num_layers = len(values) // 2
+  flat_names = context[0]
+  kv_entries = []
+  for i in range(num_layers):
+    k_cache_idx = flat_names.index(f"k_{i}")
+    v_cache_idx = flat_names.index(f"v_{i}")
+    kv_entries.append(
+        KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
     )
-
-
-
-
+  obj = KVCache(tuple(kv_entries))
+  return obj
+
+
+pytree.register_pytree_node(
+    KVCache,
+    _flatten_kvc,
+    _unflatten_kvc,
+    flatten_with_keys_fn=_flatten_kvc_with_keys,
+    serialized_type_name="",
+)
+
+
+def update(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+    enable_hlfb: bool = True,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+    enable_hlfb (bool, optional): Whether the op is annotated for export with
+      High Level Function Boundary. Defaults to True.
+
+  Returns:
+    KVCacheEntry: The updated KVCache entry based on the passed inputs.
+  """
+  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+  return update_func(cache, input_pos, k_slice, v_slice)
+
+
+def _update_kv_base_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer without High Level Function Boundary annotation."""
+  k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
+
+
+def _update_kv_hlfb_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer with High Level Function Boundary annotation."""
+  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+  )
+  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
+  k, v = builder.mark_outputs(k, v)
+  return KVCacheEntry(k, v)
ai_edge_torch/generative/layers/model_config.py

@@ -16,7 +16,7 @@
 from dataclasses import dataclass
 from dataclasses import field
 import enum
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union


 @enum.unique

@@ -30,6 +30,7 @@ class ActivationType(enum.Enum):
   GELU_QUICK = enum.auto()
   GE_GLU = enum.auto()
   RELU = enum.auto()
+  SILU_GLU = enum.auto()


 @enum.unique

@@ -58,6 +59,18 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()


+@dataclass
+class NormalizationConfig:
+  """Normalizater parameters."""
+
+  type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
+  epsilon: float = 1e-5
+  zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None
+
+
 @dataclass
 class AttentionConfig:
   """Attention model's parameters."""

@@ -81,12 +94,20 @@ class AttentionConfig:
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
+  # The normalization applied to query projection's output.
+  query_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+  # The normalization applied to key projection's output.
+  key_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
   logit_softcap: Optional[float] = None
-  # The
-
+  # The type of attention.
+  attn_type: Optional[AttentionType] = None
   # The size of the sliding window used for local attention.
   sliding_window_size: Optional[int] = None


@@ -94,20 +115,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  #
-
-
-
-
-@dataclass
-class NormalizationConfig:
-  """Normalizater parameters."""
-
-  type: NormalizationType = NormalizationType.NONE
-  epsilon: float = 1e-5
-  zero_centered: bool = False
-  # Number of groups used in group normalization.
-  group_num: Optional[float] = None
+  # Whether to GLU gate is the front part instead of the back part of input
+  # when ActivationType is `GE_GLU` or `SILU_GLU`.
+  gate_is_front: bool = False


 @dataclass

@@ -129,13 +139,8 @@ class FeedForwardConfig:


 @dataclass
-class
-  """
-
-  vocab_size: int
-  num_layers: int
-  max_seq_len: int
-  embedding_dim: int
+class TransformerBlockConfig:
+  """TransformerBlock module's parameters."""

   attn_config: AttentionConfig
   ff_config: FeedForwardConfig

@@ -147,15 +152,33 @@ class ModelConfig:
   post_attention_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )
+  # If set to True, only attn_config.pre_attention_norm is applied to the input
+  # and the decode's output is computed as `output = input + attn_out + ff_out`
+  # where attention and feed forward are called with pre_attention_norm's
+  # output.
+  parallel_residual: bool = False
+  # The Attention computation will include relative positional bias.
+  relative_attention: bool = False
+
+
+@dataclass
+class ModelConfig:
+  """Base configurations for building a transformer architecture."""
+
+  vocab_size: int
+  num_layers: int
+  max_seq_len: int
+  embedding_dim: int
+
+  # TransformerBlockConfig for each layer block. If a single
+  # TransformerBlockConfig is provided, it will be used for all layers.
+  block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]
+
   # The normalization applied before LM head.
   final_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )

-  # If set to True, only pre_attention_norm is applied to the input and the
-  # decode's output is computed as `output = input + attn_out + ff_out` where
-  # attention and feed forward are called with pre_attention_norm's output.
-  parallel_residual: bool = False
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
   # Whether to turn on high-level function boundary.

@@ -164,9 +187,6 @@ class ModelConfig:
   # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
   kv_cache_max_len: int = 0

-  # The Attention computation will include relative positional bias.
-  relative_attention: bool = False
-
   # Default batch size of the exported model. Default value is 1.
   batch_size: int = 1


@@ -177,5 +197,13 @@ class ModelConfig:
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
       return self.kv_cache_max_len
-
-
+    return self.max_seq_len
+
+  def block_config(self, idx: int) -> TransformerBlockConfig:
+    if isinstance(self.block_configs, TransformerBlockConfig):
+      return self.block_configs
+    if idx < 0 or idx >= len(self.block_configs):
+      raise ValueError(
+          f"Index {idx} is out of range for layer configs: {self.block_configs}"
+      )
+    return self.block_configs[idx]
ai_edge_torch/generative/layers/normalization.py

@@ -14,7 +14,10 @@
 # ==============================================================================
 # Common normalization layers.

+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
+from torch import nn
+import torch.nn.functional as F


 # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467

@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
       return output * (1 + self.weight)
     else:
       return output * self.weight
+
+
+class GroupNorm(torch.nn.Module):
+
+  def __init__(
+      self,
+      group_num: int,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+  ):
+    """Initialize the GroupNorm layer.
+
+    Args:
+      group_num (int): Number of groups to separate the channels into.
+      dim (int): Dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-6).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.group_num = group_num
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of GroupNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying GroupNorm.
+    """
+    if self.enable_hlfb:
+      return group_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.group_num,
+          self.eps,
+      )
+    else:
+      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(torch.nn.Module):
+
+  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+    """Initialize the LayerNorm layer.
+
+    Args:
+      dim (int): dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-6).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of LayerNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying LayerNorm.
+    """
+    if self.enable_hlfb:
+      return layer_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.eps,
+      )
+    else:
+      return F.layer_norm(
+          x,
+          x.shape,
+          self.weight.broadcast_to(x.shape),
+          self.bias.broadcast_to(x.shape),
+          self.eps,
+      )
+
+
+def group_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    num_groups: int,
+    eps: float,
+):
+  """Group Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    num_groups (int): Number of groups to separate the channels into.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Group Normalization.
+  """
+  x = torch.permute(x, (0, 2, 3, 1))
+
+  builder = StableHLOCompositeBuilder(
+      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+  )
+  x, w, b = builder.mark_inputs(x, w, b)
+  x = torch.permute(x, (0, 3, 1, 2))
+  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
+  y = torch.permute(y, (0, 2, 3, 1))
+  y = builder.mark_outputs(y)
+
+  y = torch.permute(y, (0, 3, 1, 2))
+  return y
+
+
+def layer_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    eps: float,
+):
+  """Layer Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Layer Normalization.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Layer Normalization.
+  """
+  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+  x, w, b = builder.mark_inputs(x, w, b)
+  y = F.layer_norm(
+      x,
+      x.shape,
+      weight=w.broadcast_to(x.shape),
+      bias=b.broadcast_to(x.shape),
+      eps=eps,
+  )
+  y = builder.mark_outputs(y)
+  return y
ai_edge_torch/generative/layers/unet/blocks_2d.py

@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
         config.attention_batch_size,
         config.dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )


@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
         config.query_dim,
         config.cross_dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )

ai_edge_torch/generative/quantize/example.py

@@ -25,9 +25,9 @@ def main():
   config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, 10)
+  input_pos = torch.arange(0, 10, dtype=torch.int)

   # Create a quantization recipe to be applied to the model
   quant_config = quant_recipes.full_int8_dynamic_recipe()