ai-edge-torch-nightly 0.7.0.dev20251007__py3-none-any.whl → 0.8.0.dev20251225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/__init__.py +1 -0
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +97 -22
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +94 -2
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/RECORD +42 -36
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/top_level.txt +0 -0

ai_edge_torch/_convert/conversion.py
@@ -133,10 +133,11 @@ def convert_signatures(
     exported_program = fx_infra.safe_run_decompositions(
         exported_program,
         fx_infra.decomp.pre_convert_decomp(),
+        can_skip=False,
     )
     return exported_program

-  exported_programs
+  exported_programs = [
       export(
           mod=sig.module,
           args=sig.args,
ai_edge_torch/fx_infra/_safe_run_decompositions.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 """ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
+import operator
+from typing import Any, Callable
 import torch


@@ -26,8 +28,48 @@ _DUMMY_DECOMP_TABLE = {
     torch._ops.OperatorBase(): lambda: None,
 }

+_BUILTIN_OPERATORS = {
+    getattr(operator, name)
+    for name in dir(operator)
+    if not name.startswith("_")
+}
+
+
+def _require_decomp(
+    exported_program: torch.export.ExportedProgram, decomp_table
+):
+  """Checks if the exported program requires decompositions."""
+  for node in exported_program.graph.nodes:
+    if "call_" not in str(node.op):
+      continue
+
+    op = node.target
+    if isinstance(op, torch._ops.OpOverloadPacket):
+      op = op.default
+
+    if op in decomp_table:
+      return True
+
+    if (
+        not isinstance(op, (torch._ops.OpOverload, torch._ops.OperatorBase))
+        and op not in _BUILTIN_OPERATORS
+    ):
+      # Python function that requires to be retraced via run_decompositions.
+      return True
+
+  return False

-def safe_run_decompositions(exported_program, decomp_table=None):
+
+_FORCE_DECOMP_ATTR = "_ai_edge_torch_force_decomp"
+
+
+def annotate_force_decomp(decomp: Callable[..., Any]):
+  """Annotates a decomp to force it to be run (at least shallowly) in safe_run_decompositions."""
+  setattr(decomp, _FORCE_DECOMP_ATTR, _FORCE_DECOMP_ATTR)
+  return decomp
+
+
+def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
   """Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""

   if decomp_table is not None and not decomp_table:
@@ -35,6 +77,9 @@ def safe_run_decompositions(exported_program, decomp_table=None):
     # instead for backward compatibility.
     decomp_table = _DUMMY_DECOMP_TABLE

+  if can_skip and not _require_decomp(exported_program, decomp_table):
+    return exported_program
+
   for node in exported_program.graph.nodes:
     if node.target == torch.ops.aten.view.default:
       # Passes or torch.export may generate aten.view nodes not respecting the
@@ -44,6 +89,14 @@ def safe_run_decompositions(exported_program, decomp_table=None):
       # back to one aten.view.
       node.target = lambda self, size: torch.reshape(self.contiguous(), size)

+    # Torch may skip some decompositions even if target is in decomp_table.
+    # The following ensures the target is always run through the decompositions
+    # shallowly if it has _FORCE_DECOMP_ATTR.
+    if decomp_table and node.target in decomp_table:
+      decomp = decomp_table[node.target]
+      if hasattr(decomp, _FORCE_DECOMP_ATTR):
+        node.target = decomp
+
   exported_program = exported_program.run_decompositions(decomp_table)

   if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
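
A minimal usage sketch of the new hooks above, importing them straight from this module; the gelu decomposition and the model are made up for illustration and are not part of the package:

    import torch
    from ai_edge_torch.fx_infra import _safe_run_decompositions as srd


    # Hypothetical decomposition; the decorator only tags it so that
    # safe_run_decompositions routes matching nodes through it even when torch
    # would otherwise skip the table entry.
    @srd.annotate_force_decomp
    def _gelu_decomp(x, approximate="none"):
      del approximate
      return 0.5 * x * (1.0 + torch.erf(x * 0.7071067811865476))


    class TinyModel(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.gelu(x)


    ep = torch.export.export(TinyModel(), (torch.randn(4),))
    # can_skip=True (the default) lets the wrapper return early when
    # _require_decomp finds nothing in the graph that needs the table;
    # conversion.py above passes can_skip=False to always run decompositions.
    ep = srd.safe_run_decompositions(
        ep, {torch.ops.aten.gelu.default: _gelu_decomp}, can_skip=True
    )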

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -138,9 +138,7 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

-  quant_config = (
-      quant_recipes.full_int8_weight_only_recipe() if quantize else None
-  )
+  quant_config = quant_recipes.full_weight_only_recipe() if quantize else None

   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder

ai_edge_torch/generative/layers/attention.py
@@ -18,6 +18,7 @@
 import abc
 from typing import Optional, Tuple, Union

+from ai_edge_torch.generative.layers import attention_utils
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
@@ -240,13 +241,35 @@ class CausalSelfAttention(CausalSelfAttentionBase):
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)

-    if rope is not None:
+    alibi_bias = None
+    if self.config.use_alibi:
+      k_size = T
+      if mask is not None:
+        k_size = mask.shape[-1]
+      elif input_pos is not None:
+        # If mask is not present, assume current sequence length is key length.
+        k_size = input_pos[-1].item() + 1
+      alibi_bias = attention_utils.build_alibi_bias(
+          n_heads=self.config.num_heads,
+          k_size=k_size,
+          dtype=x.dtype,
+          device=x.device,
+      )
+    elif rope is not None:
       # Compute rotary positional embedding for query and key.
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
-        q,
+        q,
+        k,
+        v,
+        kv_cache,
+        input_pos,
+        mask,
+        self.config,
+        self.enable_hlfb,
+        alibi_bias=alibi_bias,
     )

     # Compute the output projection.
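
The key-length selection in the new ALiBi branch can be read as a small standalone rule; this is a paraphrase of the hunk above for illustration, not code from the package:

    from typing import Optional

    import torch


    def alibi_key_length(
        t: int,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> int:
      """Mirrors how the forward pass above picks k_size for the ALiBi bias."""
      if mask is not None:
        # The mask's last dimension already covers the full key/value length.
        return mask.shape[-1]
      if input_pos is not None:
        # Without a mask, the largest position index bounds the key length.
        return int(input_pos[-1].item()) + 1
      # Fall back to the current sequence length T.
      return t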

ai_edge_torch/generative/layers/attention_test.py
@@ -27,16 +27,27 @@ class AttentionTest(parameterized.TestCase):
       dict(
           testcase_name="local_causal_self_attention",
           attn_type=cfg.AttentionType.LOCAL_SLIDING,
+          use_alibi=False,
           expected_shape=(1, 10, 16),
       ),
       dict(
           testcase_name="global_causal_self_attention",
           attn_type=cfg.AttentionType.GLOBAL,
+          use_alibi=False,
+          expected_shape=(1, 10, 16),
+      ),
+      dict(
+          testcase_name="alibi_attention",
+          attn_type=cfg.AttentionType.GLOBAL,
+          use_alibi=True,
           expected_shape=(1, 10, 16),
       ),
   )
   def test_causal_self_attention(
-      self,
+      self,
+      attn_type: cfg.AttentionType,
+      use_alibi: bool,
+      expected_shape: tuple[int, ...],
   ):
     norm_config = cfg.NormalizationConfig(
         type=cfg.NormalizationType.RMS_NORM,
@@ -56,6 +67,7 @@ class AttentionTest(parameterized.TestCase):
         logit_softcap=None,
         sliding_window_size=16,
         attn_type=attn_type,
+        use_alibi=use_alibi,
     )
     self_atten = attention.CausalSelfAttention(
         dim=16,

ai_edge_torch/generative/layers/attention_utils.py
@@ -15,11 +15,72 @@
 # Common utility functions used with attention module.

 import math
-from typing import Tuple
+from typing import List, Tuple

 import torch


+def _get_alibi_slopes(n_heads: int) -> List[float]:
+  """Returns slopes for ALiBi implementation.
+
+  The slopes are taken from the ALiBi paper
+  [https://arxiv.org/abs/2108.12409].
+  The slopes are later used to calculate the bias which is added to the
+  attention scores.
+
+  Args:
+    n_heads (int): The number of attention heads.
+  """
+
+  def get_slopes_power_of_2(n):
+    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+    return [start**i for i in range(1, n + 1)]
+
+  if math.log2(n_heads).is_integer():
+    return get_slopes_power_of_2(n_heads)
+  else:
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
+    return (
+        get_slopes_power_of_2(closest_power_of_2)
+        + _get_alibi_slopes(2 * closest_power_of_2)[0::2][
+            : n_heads - closest_power_of_2
+        ]
+    )
+
+
+def build_alibi_bias(
+    n_heads: int,
+    k_size: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = None,
+) -> torch.Tensor:
+  """Builds ALiBi bias tensor based on key position.
+
+  The bias tensor is added to the attention scores before softmax.
+  Replicates HuggingFace Falcon implementation behavior where bias only depends
+  on key position j, not relative position j-i.
+
+  Args:
+    n_heads (int): The number of attention heads.
+    k_size (int): The key size of the bias tensor.
+    dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+      torch.float32.
+    device (torch.device, optional): Output tensor's data type. Defaults to
+      None in which case "cpu" is used.
+
+  Returns:
+    torch.Tensor: The ALiBi bias tensor of shape (1, n_heads, 1, k_size).
+  """
+  if device is None:
+    device = torch.device('cpu')
+  slopes = torch.tensor(_get_alibi_slopes(n_heads), dtype=dtype, device=device)
+  k_pos = torch.arange(k_size, device=device)
+  # According to HF implementation, bias only depends on key position.
+  # slopes[h] * k_pos[j]
+  alibi_bias = slopes.unsqueeze(-1) * k_pos.unsqueeze(0)  # Shape: H, K
+  return alibi_bias[None, :, None, :].to(dtype)
+
+
 def build_rope_cache(
     size: int,
     dim: int,
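
For a sense of how the (1, n_heads, 1, k_size) bias is consumed, a small illustration with synthetic scores; in the package the bias is additionally scaled and folded into the attention mask (see the scaled_dot_product_attention.py hunks further below):

    import torch
    from ai_edge_torch.generative.layers import attention_utils

    B, H, Q, K = 1, 4, 1, 8            # batch, heads, query length, key length
    scores = torch.zeros(B, H, Q, K)   # stand-in for q @ k^T, synthetic values

    bias = attention_utils.build_alibi_bias(n_heads=H, k_size=K)
    assert bias.shape == (1, H, 1, K)

    # The bias broadcasts over batch and query positions: every query row of a
    # given head receives the same per-key offset slope[h] * j.
    probs = torch.softmax(scores + bias, dim=-1)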

ai_edge_torch/generative/layers/attention_utils_test.py
@@ -21,6 +21,26 @@ from absl.testing import absltest as googletest

 class AttentionUtilsTest(googletest.TestCase):

+  def test_get_alibi_slopes(self):
+    slopes = attention_utils._get_alibi_slopes(1)
+    self.assertSequenceAlmostEqual(slopes, [0.00390625], places=6)
+    slopes = attention_utils._get_alibi_slopes(2)
+    self.assertSequenceAlmostEqual(slopes, [0.0625, 0.00390625], places=6)
+    slopes = attention_utils._get_alibi_slopes(4)
+    self.assertSequenceAlmostEqual(
+        slopes, [0.25, 0.0625, 0.015625, 0.00390625], places=6
+    )
+    slopes = attention_utils._get_alibi_slopes(3)
+    self.assertSequenceAlmostEqual(slopes, [0.0625, 0.00390625, 0.25], places=6)
+
+  def test_build_alibi_bias(self):
+    bias = attention_utils.build_alibi_bias(n_heads=2, k_size=3)
+    self.assertEqual(bias.shape, (1, 2, 1, 3))
+    expected = torch.tensor(
+        [[[[0.0, 0.0625, 0.125]], [[0.0, 0.00390625, 0.0078125]]]]
+    )
+    torch.testing.assert_close(bias, expected)
+
   def test_build_causal_mask_cache(self):
     mask = attention_utils.build_causal_mask_cache(3)
     self.assertEqual(mask.shape, (1, 1, 3, 3))

ai_edge_torch/generative/layers/builder.py
@@ -71,7 +71,7 @@ def build_norm(
   Raises:
     ValueError: If config's `layer_norm_type` is not supported.
   """
-  if config.type == cfg.NormalizationType.NONE:
+  if config is None or config.type == cfg.NormalizationType.NONE:
     return lambda x: x
   elif config.type == cfg.NormalizationType.RMS_NORM:
     return normalization.RMSNorm(
@@ -84,7 +84,9 @@ def build_norm(
         init_fn=init_fn,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.use_bias, config.enable_hlfb
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb

ai_edge_torch/generative/layers/model_config.py
@@ -75,6 +75,8 @@ class NormalizationConfig:
   scale_shift: float = 0.0
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use bias in norm.
+  use_bias: bool = True


 # Exprimental feature and may subject to change.
@@ -108,6 +110,8 @@ class AttentionConfig:
   rotary_base: int = 10_000
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
+  # Whether to use ALiBi positional encoding.
+  use_alibi: bool = False
   # Whether to transpose the query groups of qkv bundled tensor before
   # splitting into separated tensors.
   qkv_transpose_before_split: bool = False
@@ -247,6 +251,7 @@ class ModelConfig:
   lm_head_use_bias: bool = False
   # Whether LLM's HEAD shares the weight of the embedding.
   lm_head_share_weight_with_embedding: bool = True
+  dense_intermediate_size: Optional[int] = None

   # Whether to turn on high-level function boundary.
   enable_hlfb: bool = True
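
A hedged example of how the new fields might be set together; only the field names come from the hunks above, while the surrounding constructor arguments are guesses and may not match the real set of required fields:

    from ai_edge_torch.generative.layers import model_config as cfg

    # Bias-free LayerNorm via the new NormalizationConfig.use_bias flag.
    norm_config = cfg.NormalizationConfig(
        type=cfg.NormalizationType.LAYER_NORM,
        use_bias=False,
    )

    # Attention that relies on ALiBi rather than RoPE via AttentionConfig.use_alibi.
    attn_config = cfg.AttentionConfig(
        num_heads=8,
        head_dim=64,
        num_query_groups=8,
        use_alibi=True,
    )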

ai_edge_torch/generative/layers/normalization.py
@@ -148,6 +148,7 @@ class LayerNorm(torch.nn.Module):
       self,
       dim: int,
       eps: float = 1e-5,
+      use_bias: bool = True,
       enable_hlfb: bool = False,
   ):
     """Initialize the LayerNorm layer.
@@ -156,6 +157,7 @@ class LayerNorm(torch.nn.Module):
       dim (int): dimension of the input tensor.
       eps (float): A small float value to ensure numerical stability (default:
         1e-5).
+      use_bias (bool): Whether to use bias in LayerNorm.
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
     """
@@ -164,7 +166,11 @@ class LayerNorm(torch.nn.Module):
     self.normalized_shape = (dim,)
     self.eps = eps
     self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
-    self.bias =
+    self.bias = (
+        torch.nn.Parameter(torch.empty(dim), requires_grad=False)
+        if use_bias
+        else None
+    )

   def forward(self, x):
     """Running the forward pass of LayerNorm layer.
@@ -175,7 +181,7 @@ class LayerNorm(torch.nn.Module):
     Returns:
       torch.Tensor: output tensor after applying LayerNorm.
     """
-    if self.enable_hlfb:
+    if self.enable_hlfb and self.bias is not None:
       return layer_norm_with_hlfb(
           x, self.normalized_shape, self.weight, self.bias, self.eps
       )
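
As a sanity check on what a bias-free layer norm computes, a plain-torch comparison that is independent of the class above:

    import torch
    import torch.nn.functional as F

    dim = 8
    x = torch.randn(2, dim)
    weight = torch.ones(dim)

    # F.layer_norm accepts bias=None for a bias-free normalization.
    y = F.layer_norm(x, (dim,), weight=weight, bias=None, eps=1e-5)

    # Equivalent explicit form: normalize, then scale, with no shift term.
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    y_ref = (x - mean) / torch.sqrt(var + 1e-5) * weight
    torch.testing.assert_close(y, y_ref)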

ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -32,6 +32,7 @@ def scaled_dot_product_attention(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
 ):
   """Scaled dot product attention.

@@ -41,14 +42,23 @@ def scaled_dot_product_attention(
     v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
     head_size (int): head dimension.
     mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.

   Returns:
     The output tensor of scaled_dot_product_attention.
   """
-
   if scale is None:
     scale = 1.0 / math.sqrt(head_size)

+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
   q = q.transpose(1, 2)
   k = k.transpose(1, 2)
   v = v.transpose(1, 2)
@@ -72,7 +82,8 @@ def scaled_dot_product_attention(
     scores = scores / softcap
     scores = torch.tanh(scores)
     scores = scores * softcap
-    scores = scores + mask
+    if mask is not None:
+      scores = scores + mask
     out = F.softmax(scores.float(), dim=-1).type_as(q)
     y = torch.matmul(out, v)

@@ -87,6 +98,7 @@ def scaled_dot_product_attention_with_hlfb(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
 ):
   """Scaled dot product attention with high-level function boundary enabled.

@@ -96,14 +108,23 @@ def scaled_dot_product_attention_with_hlfb(
     v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
     head_size (int): head dimension.
     mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.

   Returns:
     The output tensor of scaled_dot_product_attention.
   """
-
   if scale is None:
     scale = 1.0 / math.sqrt(head_size)

+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
   attrs = {"scale": scale}

   if softcap is not None:
@@ -137,7 +158,8 @@ def scaled_dot_product_attention_with_hlfb(
     scores = scores / softcap
     scores = torch.tanh(scores)
     scores = scores * softcap
-    scores = scores + mask
+    if mask is not None:
+      scores = scores + mask
     out = F.softmax(scores.float(), dim=-1).type_as(q)
     y = torch.matmul(out, v)

@@ -154,6 +176,7 @@ def scaled_dot_product_attention_transposed(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
 ):
   """Scaled dot product attention with transposed key and value.

@@ -165,14 +188,21 @@ def scaled_dot_product_attention_transposed(
     mask (torch.Tensor): the optional mask tensor.
     scale (float): the optional scale factor.
     softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.

   Returns:
     The output tensor of scaled_dot_product_attention_transposed.
   """
-
   if scale is None:
     scale = 1.0 / math.sqrt(head_size)

+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
   query = query * scale

   assert mask is not None, "Mask should not be None!"
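
The preprocessing repeated in all three functions above (scale the bias, then fold it into the additive mask) reduces to a small helper; this is a sketch for illustration, not package code:

    import math
    from typing import Optional

    import torch


    def fold_alibi_into_mask(
        mask: Optional[torch.Tensor],
        alibi_bias: Optional[torch.Tensor],
        head_size: int,
        scale: Optional[float] = None,
    ) -> Optional[torch.Tensor]:
      """Mirrors the shared ALiBi handling in the three SDPA variants above."""
      if scale is None:
        scale = 1.0 / math.sqrt(head_size)
      if alibi_bias is None:
        return mask
      # Pre-scale the bias by the same factor applied to the attention scores,
      # then merge it into the additive mask so the rest of SDPA is unchanged.
      alibi_bias = alibi_bias * scale
      return alibi_bias if mask is None else mask + alibi_bias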

ai_edge_torch/generative/layers/sdpa_with_kv_update.py
@@ -15,7 +15,7 @@

 """Common utility functions for data loading etc."""

-from typing import Tuple
+from typing import Optional, Tuple

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
@@ -32,14 +32,15 @@ def sdpa_with_kv_update(
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
     enable_hlfb: bool,
+    alibi_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
   """Wrapper function for scaled dot product attention with KV cache update."""
   if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED:
     return _sdpa_with_kv_update_transposed(
-        query, key, value, kv, input_pos, mask, config
+        query, key, value, kv, input_pos, mask, config, alibi_bias
     )
   return _sdpa_with_kv_update_default(
-      query, key, value, kv, input_pos, mask, config, enable_hlfb
+      query, key, value, kv, input_pos, mask, config, enable_hlfb, alibi_bias
   )


@@ -51,6 +52,7 @@ def _sdpa_with_kv_update_transposed(
     input_pos: torch.Tensor,
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
+    alibi_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
   # Transpose k/v to specific layout for GPU implementation.
   b, seq_len, n, h = query.shape
@@ -77,6 +79,7 @@ def _sdpa_with_kv_update_transposed(
       config.head_dim,
       mask=mask,
       softcap=config.logit_softcap,
+      alibi_bias=alibi_bias,
   )  # 1, bk, gt, h
   sdpa_out = (
       sdpa_out.reshape(b, -1, seq_len, h)
@@ -95,6 +98,7 @@ def _sdpa_with_kv_update_default(
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
     enable_hlfb: bool,
+    alibi_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
   b, seq_len, _, _ = query.shape
   if kv is not None:
@@ -112,6 +116,7 @@ def _sdpa_with_kv_update_default(
       config.head_dim,
       mask=mask,
       softcap=config.logit_softcap,
+      alibi_bias=alibi_bias,
   )
   sdpa_out = sdpa_out.reshape(b, seq_len, -1)
   return sdpa_out, kv

ai_edge_torch/generative/quantize/example.py
@@ -33,7 +33,7 @@ def main():
   kv = kv_utils.KVCache.from_model_config(config)

   # Create a quantization recipe to be applied to the model
-  quant_config = quant_recipes.
+  quant_config = quant_recipes.full_dynamic_recipe()
   print(quant_config)

   # Convert with quantization

ai_edge_torch/generative/quantize/quant_attrs.py
@@ -63,8 +63,15 @@ class Granularity(enum.Enum):
     NONE: Granularity not applicable to this quantization scheme.
     CHANNELWISE: Or per-channel quantization. Each channel of relevant tensors
       is quantized independently of one another.
+    BLOCKWISE_32: Blockwise quantization with block size 32.
+    BLOCKWISE_64: Blockwise quantization with block size 64.
+    BLOCKWISE_128: Blockwise quantization with block size 128.
+    BLOCKWISE_256: Blockwise quantization with block size 256.
   """

   NONE = enum.auto()
   CHANNELWISE = enum.auto()
-  BLOCKWISE = enum.auto()
+  BLOCKWISE_32 = enum.auto()
+  BLOCKWISE_64 = enum.auto()
+  BLOCKWISE_128 = enum.auto()
+  BLOCKWISE_256 = enum.auto()
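
For intuition about the new BLOCKWISE_* granularities, a generic symmetric int8 blockwise-quantization sketch with block size 32; this is illustrative only and not how the converter's quantizer is implemented:

    import torch


    def quantize_blockwise_int8(w: torch.Tensor, block_size: int = 32):
      """Symmetric int8 quantization with one scale per block of the last axis."""
      out_features, in_features = w.shape
      assert in_features % block_size == 0
      blocks = w.reshape(out_features, in_features // block_size, block_size)
      # One scale per (row, block) rather than per row (channelwise) or per
      # tensor: smaller blocks track the local dynamic range more tightly.
      scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
      q = torch.clamp(torch.round(blocks / scales), -128, 127).to(torch.int8)
      return q, scales


    w = torch.randn(16, 64)
    q, scales = quantize_blockwise_int8(w, block_size=32)
    w_hat = (q.float() * scales).reshape(w.shape)  # dequantized approximation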

ai_edge_torch/generative/quantize/quant_recipe.py
@@ -39,7 +39,6 @@ class LayerQuantRecipe:
     mode: Type of quantization.
     algorithm: Algorithm for calculating quantization parameters.
     granularity: Granularity of quantization.
-    block_size: Size of the block for blockwise quantization.
   """

   activation_dtype: quant_attrs.Dtype
@@ -47,7 +46,6 @@ class LayerQuantRecipe:
   mode: quant_attrs.Mode
   algorithm: quant_attrs.Algorithm
   granularity: quant_attrs.Granularity
-  block_size: int = 0

   def __str__(self):
     base_str = (
@@ -56,7 +54,6 @@ class LayerQuantRecipe:
         f'{self.mode.name}, '
         f'{self.algorithm.name}, '
         f'{self.granularity.name}, '
-        f'{self.block_size}'
     )
     return f'{base_str})'

@@ -77,16 +74,6 @@ class LayerQuantRecipe:
         and self.algorithm == supported[3]
         and self.granularity == supported[4]
     ):
-      if self.block_size > 0:
-        if (
-            self.block_size % 32 == 0
-            and self.granularity == quant_attrs.Granularity.BLOCKWISE
-        ):
-          is_valid = True
-          break
-        else:
-          is_valid = False
-          break
       is_valid = True
       break
