floydnet 0.1.1__tar.gz → 0.1.2__tar.gz
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
- {floydnet-0.1.1 → floydnet-0.1.2}/PKG-INFO +6 -8
- {floydnet-0.1.1 → floydnet-0.1.2}/README.md +5 -7
- {floydnet-0.1.1 → floydnet-0.1.2}/example/README.md +8 -0
- {floydnet-0.1.1 → floydnet-0.1.2}/pyproject.toml +2 -2
- floydnet-0.1.2/src/floydnet/__init__.py +8 -0
- floydnet-0.1.2/src/floydnet/functional.py +150 -0
- floydnet-0.1.2/src/floydnet/transformer.py +219 -0
- {floydnet-0.1.1 → floydnet-0.1.2}/.gitignore +0 -0
- {floydnet-0.1.1 → floydnet-0.1.2}/CHANGELOG.md +0 -0
- {floydnet-0.1.1 → floydnet-0.1.2}/CITATION.cff +0 -0
- {floydnet-0.1.1 → floydnet-0.1.2}/LICENSE +0 -0
{floydnet-0.1.1 → floydnet-0.1.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: floydnet
-Version: 0.1.1
+Version: 0.1.2
 Summary: Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs
 Project-URL: Homepage, https://github.com/ocx-lab/FloydNet
 Project-URL: Repository, https://github.com/ocx-lab/FloydNet

@@ -231,15 +231,17 @@ Requires-Dist: ruff>=0.5; extra == 'dev'
 Description-Content-Type: text/markdown

 # FloydNet
+[](https://opensource.org/licenses/Apache-2.0)
+[](https://www.python.org/)
+[](https://pytorch.org/)

 Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).

 

 This repository serves two audiences:
-
-- **
-- **Research users**: scripts/configs to reproduce paper experiments under `example/`.
+- **Engineering users**: Reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
+- **Research users**: Scripts/configs to reproduce paper experiments (TSP, Graph Isomorphism, BREC) under `example/`.

 ---


@@ -267,10 +269,6 @@ For algorithmic details, hyperparameter choices, and analysis, please refer to t

 ---

-## Using the Attention / Transformer Block
-
-This section targets **engineering users** who want to import FloydNet as a dependency.
-
 ### Installation

 #### Option A: Install from PyPI
{floydnet-0.1.1 → floydnet-0.1.2}/README.md

@@ -1,13 +1,15 @@
 # FloydNet
+[](https://opensource.org/licenses/Apache-2.0)
+[](https://www.python.org/)
+[](https://pytorch.org/)

 Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).

 

 This repository serves two audiences:
-
-- **
-- **Research users**: scripts/configs to reproduce paper experiments under `example/`.
+- **Engineering users**: Reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
+- **Research users**: Scripts/configs to reproduce paper experiments (TSP, Graph Isomorphism, BREC) under `example/`.

 ---


@@ -35,10 +37,6 @@ For algorithmic details, hyperparameter choices, and analysis, please refer to t

 ---

-## Using the Attention / Transformer Block
-
-This section targets **engineering users** who want to import FloydNet as a dependency.
-
 ### Installation

 #### Option A: Install from PyPI
{floydnet-0.1.1 → floydnet-0.1.2}/example/README.md

@@ -6,6 +6,14 @@ The paper reports results on **three benchmarks**:
 - BREC
 - TSP

+## 🚀 Key Results
+
+| Domain | Benchmark | Result |
+| :--- | :--- | :--- |
+| **Algorithmic** | CLRS-30 | **96.64%** (SOTA), effectively solving graph & string algorithms. |
+| **Optimization** | Non-Metric TSP | **88.3%** optimality on unseen sizes ($N=200$), vs 1.3% for Linkern heuristic. |
+| **Expressiveness** | Substructure Count | Near-zero error (MAE **0.001**) on complex substructure counting. |
+
 ### Graph Count

 The Graph Count benchmark and dataset construction follow:
{floydnet-0.1.1 → floydnet-0.1.2}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "floydnet"
-version = "0.1.1"
+version = "0.1.2"
 description = "Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs"
 readme = "README.md"
 requires-python = ">=3.9"

@@ -50,7 +50,7 @@ include = [
 "LICENSE",
 "CITATION.cff",
 "CHANGELOG.md",
-"
+"src/**",
 ]

 [tool.hatch.build.targets.wheel]
floydnet-0.1.2/src/floydnet/functional.py

@@ -0,0 +1,150 @@
+# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import annotations
+
+from typing import Optional
+import math
+
+import torch
+import torch.nn.functional as F
+
+def pivotal_attention(
+    q_ik: torch.Tensor,
+    k_ij: torch.Tensor,
+    k_jk: torch.Tensor,
+    v_ij: torch.Tensor,
+    v_jk: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scale: Optional[float] = None,
+    inf: float = 1e9,
+) -> torch.Tensor:
+    """Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
+
+    Shapes:
+        q_ik: (B, H, L_i, L_k, D)
+        k_ij: (B, H, L_i, L_j, D)
+        k_jk: (B, H, L_j, L_k, D)
+        v_ij: (B, H, L_i, L_j, D)
+        v_jk: (B, H, L_j, L_k, D)
+        attn_mask (optional): broadcastable to (B, H, L_i, L_k, L_j)
+
+    Args:
+        attn_mask: Additive mask (float) or boolean mask. If boolean, masked positions are set to -inf.
+        dropout: Dropout probability applied to attention weights (only effective if > 0).
+        scale: Optional custom scaling factor. If None, defaults to 1/sqrt(2*D).
+        inf: Value to use for -infinity in masks.
+
+    Returns:
+        Tensor of shape (B, H, L_i, L_k, D)
+    """
+    assert all([t.dim() == 5 for t in [q_ik, k_ij, k_jk, v_ij, v_jk]]), "All inputs must be 5D tensors"
+    B, H, L_i, L_k, D = q_ik.shape
+    L_j = k_ij.shape[3]
+    assert k_ij.shape == v_ij.shape == (B, H, L_i, L_j, D), "k_ij and v_ij must have shape (B, H, L_i, L_j, D)"
+    assert k_jk.shape == v_jk.shape == (B, H, L_j, L_k, D), "k_jk and v_jk must have shape (B, H, L_j, L_k, D)"
+
+    if scale is None:
+        scale = 1.0 / math.sqrt(2.0 * D)
+    q_ik = q_ik * scale
+
+    # Compute attention scores over the pivot dimension j: (B, H, L_i, L_k, L_j)
+    attn_scores = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij) \
+        + torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_scores = attn_scores.masked_fill(attn_mask, -inf)
+        else:
+            attn_scores = attn_scores + attn_mask
+
+    attn_weights = torch.softmax(attn_scores, dim=-1)
+
+    if dropout > 0.0:
+        attn_weights = F.dropout(attn_weights, p=dropout)
+
+    y = torch.einsum("bhikj,bhijd->bhikd", attn_weights, v_ij) \
+        + torch.einsum("bhikj,bhjkd->bhikd", attn_weights, v_jk)
+
+    return y
+
+def pivotal_attention3(
+    q_ijk: torch.Tensor,
+    k_pjk: torch.Tensor,
+    k_ipk: torch.Tensor,
+    k_ijp: torch.Tensor,
+    v_pjk: torch.Tensor,
+    v_ipk: torch.Tensor,
+    v_ijp: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scale: Optional[float] = None,
+    inf: float = 1e9,
+) -> torch.Tensor:
+    """3-Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
+
+    Shapes:
+        q_ijk: (B, H, L_i, L_j, L_k, D)
+        k_pjk: (B, H, L_p, L_j, L_k, D)
+        k_ipk: (B, H, L_i, L_p, L_k, D)
+        k_ijp: (B, H, L_i, L_j, L_p, D)
+        v_pjk: (B, H, L_p, L_j, L_k, D)
+        v_ipk: (B, H, L_i, L_p, L_k, D)
+        v_ijp: (B, H, L_i, L_j, L_p, D)
+        attn_mask (optional): broadcastable to (B, H, L_i, L_j, L_k, L_p)
+
+    Args:
+        attn_mask: Additive mask (float) or boolean mask. If boolean, masked positions are set to -inf.
+        dropout: Dropout probability applied to attention weights (only effective if > 0).
+        scale: Optional custom scaling factor. If None, defaults to 1/sqrt(3*D).
+        inf: Value to use for -infinity in masks.
+
+    Returns:
+        Tensor of shape (B, H, L_i, L_j, L_k, D)
+    """
+    assert all([t.dim() == 6 for t in [q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp]]), "All inputs must be 6D tensors"
+    B, H, L_i, L_j, L_k, D = q_ijk.shape
+    L_p = k_pjk.shape[2]
+    assert k_pjk.shape == v_pjk.shape == (B, H, L_p, L_j, L_k, D), "k_pjk and v_pjk must have shape (B, H, L_p, L_j, L_k, D)"
+    assert k_ipk.shape == v_ipk.shape == (B, H, L_i, L_p, L_k, D), "k_ipk and v_ipk must have shape (B, H, L_i, L_p, L_k, D)"
+    assert k_ijp.shape == v_ijp.shape == (B, H, L_i, L_j, L_p, D), "k_ijp and v_ijp must have shape (B, H, L_i, L_j, L_p, D)"
+
+    if scale is None:
+        scale = 1.0 / math.sqrt(3.0 * D)
+    q_ijk = q_ijk * scale
+
+    # Compute attention scores over the pivot dimension p: (B, H, L_i, L_j, L_k, L_p)
+    attn_scores = torch.einsum("bhijkd,bhpjkd->bhijkp", q_ijk, k_pjk) \
+        + torch.einsum("bhijkd,bhipkd->bhijkp", q_ijk, k_ipk) \
+        + torch.einsum("bhijkd,bhijpd->bhijkp", q_ijk, k_ijp)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_scores = attn_scores.masked_fill(attn_mask, -inf)
+        else:
+            attn_scores = attn_scores + attn_mask
+
+    attn_weights = torch.softmax(attn_scores, dim=-1)
+
+    if dropout > 0.0:
+        attn_weights = F.dropout(attn_weights, p=dropout)
+
+    y = torch.einsum("bhijkp,bhpjkd->bhijkd", attn_weights, v_pjk) \
+        + torch.einsum("bhijkp,bhipkd->bhijkd", attn_weights, v_ipk) \
+        + torch.einsum("bhijkp,bhijpd->bhijkd", attn_weights, v_ijp)
+
+    return y
+
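For context, here is a minimal usage sketch of the functional API added above. The tensor sizes and the all-False mask are illustrative choices rather than values from the package, and the import path assumes the installed wheel exposes `floydnet.functional` exactly as laid out under `src/floydnet/` in this diff.

```python
# Hypothetical example: calling pivotal_attention on random pair features.
import torch

from floydnet.functional import pivotal_attention  # module path as shipped in 0.1.2

B, H, L, D = 2, 4, 8, 16  # batch, heads, grid side length, head dim (arbitrary sizes)

# All five inputs live on an (L x L) pair grid; here L_i = L_j = L_k = L.
q_ik = torch.randn(B, H, L, L, D)
k_ij = torch.randn(B, H, L, L, D)
k_jk = torch.randn(B, H, L, L, D)
v_ij = torch.randn(B, H, L, L, D)
v_jk = torch.randn(B, H, L, L, D)

# Optional boolean mask over the pivot dimension j, broadcastable to
# (B, H, L_i, L_k, L_j); True marks positions to exclude. Here nothing is masked.
attn_mask = torch.zeros(B, 1, 1, 1, L, dtype=torch.bool)

out = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 8, 8, 16]), i.e. (B, H, L_i, L_k, D)
```

Note that the scores for each (i, k) query sum two dot products over the pivot index j, one against the (i, j) keys and one against the (j, k) keys, which is why the default scale is 1/sqrt(2*D).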
floydnet-0.1.2/src/floydnet/transformer.py

@@ -0,0 +1,219 @@
+# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+from typing import Optional, Tuple, Union, Callable
+
+import torch
+from torch import nn
+from .functional import pivotal_attention
+
+
+class Affine(nn.Module):
+    def __init__(self, c):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones((c, )))
+        self.bias = nn.Parameter(torch.zeros((c, )))
+
+    def forward(self, x: torch.Tensor):
+        return x * self.weight + self.bias
+
+
+def create_norm(norm_fn: Union[str, Callable], embed_dim: int, eps: float = 1e-5, **kwargs) -> nn.Module:
+    """Create a normalization module from a name or nn.Module.
+
+    Args:
+        norm_fn: Name or an nn.Module instance/class.
+        embed_dim: Embedding dimension (features) used to construct the norm.
+        eps: Numerical epsilon passed to the normalization layer if applicable.
+        **kwargs: Extra keyword arguments forwarded to the normalization layer.
+
+    Returns:
+        An nn.Module normalization instance.
+    """
+    if isinstance(norm_fn, str):
+        if norm_fn.lower() in ["layernorm", "ln"]:
+            return nn.LayerNorm(embed_dim, eps=eps, **kwargs)
+        elif norm_fn.lower() in ["batchnorm", "bn"]:
+            return nn.BatchNorm1d(embed_dim, eps=eps, **kwargs)
+        elif norm_fn.lower() in ["rmsnorm", "rms"]:
+            return nn.RMSNorm(embed_dim, eps=eps, **kwargs)
+        elif norm_fn.lower() in ["affine"]:
+            return Affine(embed_dim)
+        elif norm_fn.lower() in ["none", "identity"]:
+            return nn.Identity()
+        else:
+            raise ValueError(f"Unsupported norm_fn string: {norm_fn}")
+    elif callable(norm_fn):
+        if isinstance(norm_fn, nn.Module):
+            # deepcopy to avoid shared parameters
+            return copy.deepcopy(norm_fn)
+        elif isinstance(norm_fn, type) and issubclass(norm_fn, nn.Module):
+            return norm_fn(embed_dim, eps=eps, **kwargs)
+        else:
+            raise TypeError("norm_fn callable must be an nn.Module or nn.Module class")
+    else:
+        raise TypeError("norm_fn must be a string or callable")
+
+
+def create_activation(activation_fn: Union[str, Callable]) -> nn.Module:
+    """Create an activation module from a name or nn.Module.
+
+    Args:
+        activation_fn: Name or an nn.Module instance/class.
+
+    Returns:
+        An nn.Module activation instance.
+    """
+    if isinstance(activation_fn, str):
+        if activation_fn.lower() == "relu":
+            return nn.ReLU()
+        elif activation_fn.lower() == "gelu":
+            return nn.GELU()
+        elif activation_fn.lower() == "silu":
+            return nn.SiLU()
+        else:
+            raise ValueError(f"Unsupported activation_fn string: {activation_fn}")
+    elif callable(activation_fn):
+        if isinstance(activation_fn, nn.Module):
+            return activation_fn
+        elif isinstance(activation_fn, type) and issubclass(activation_fn, nn.Module):
+            return activation_fn()
+        else:
+            raise TypeError("activation_fn callable must be an nn.Module or nn.Module class")
+    else:
+        raise TypeError("activation_fn must be a string or callable")
+
+
+class PivotalAttentionBlock(nn.Module):
+    """Transformer-style block that applies pivotal attention followed by an FFN.
+
+    Args:
+        embed_dim: Input/hidden embedding dimension (D).
+        num_heads: Number of attention heads (D must be divisible by num_heads).
+        dropout: Dropout probability for attention output and FFN output.
+        bias: Whether to include bias terms in linear layers.
+        ffn_expansion_ratio: Expansion ratio for the FFN hidden size.
+        norm_position: "pre" or "post" layer normalization placement.
+        activation_fn: Activation name/module used in the FFN.
+        norm_fn: Normalization name/module used in the block.
+    """
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        bias: bool = False,
+        ffn_expansion_ratio: int = 4,
+        norm_position: str = "pre",
+        activation_fn: Union[str, Callable] = "relu",
+        norm_fn: Union[str, Callable] = "layernorm",
+        enable_symmetric_mix: bool = True,
+        enable_ffn: bool = True,
+    ) -> None:
+        super().__init__()
+        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.dropout = dropout
+        self.norm_position = norm_position.lower()
+        self.enable_ffn = enable_ffn
+        assert self.norm_position in ["pre", "post"], "norm_position must be 'pre' or 'post'"
+
+        self.enable_symmetric_mix = enable_symmetric_mix
+        if enable_symmetric_mix:
+            self.c_mix = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        self.c_qkv = nn.Linear(embed_dim, embed_dim * 5, bias=bias)
+        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.dropout_fn = nn.Dropout(dropout)
+        self.norm1 = create_norm(norm_fn, embed_dim)
+        if self.enable_ffn:
+            self.activation_fn = create_activation(activation_fn)
+            self.norm2 = create_norm(norm_fn, embed_dim)
+            self.ffn = nn.Sequential(
+                nn.Linear(embed_dim, ffn_expansion_ratio * embed_dim, bias=bias),
+                self.activation_fn,
+                nn.Linear(ffn_expansion_ratio * embed_dim, embed_dim, bias=bias),
+                nn.Dropout(dropout),
+            )
+            self.ffn_scale = nn.Parameter(torch.tensor(1.0, requires_grad=True))
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        """Initialize parameters using Xavier for projections and zeros for output heads."""
+        if self.enable_symmetric_mix:
+            nn.init.zeros_(self.c_mix.weight)
+        nn.init.xavier_uniform_(self.c_qkv.weight)
+        nn.init.zeros_(self.c_proj.weight)
+        if self.enable_ffn:
+            nn.init.xavier_uniform_(self.ffn[0].weight)
+            nn.init.zeros_(self.ffn[2].weight)
+        if self.c_qkv.bias is not None:
+            if self.enable_symmetric_mix:
+                nn.init.zeros_(self.c_mix.bias)
+            nn.init.zeros_(self.c_qkv.bias)
+            nn.init.zeros_(self.c_proj.bias)
+            if self.enable_ffn:
+                nn.init.zeros_(self.ffn[0].bias)
+                nn.init.zeros_(self.ffn[2].bias)
+
+    def attn(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        """Apply pivotal attention over a (L x L) grid.
+
+        Args:
+            x: Input tensor of shape (B, L, L, D).
+            attn_mask: Optional mask broadcastable to (B, H, L, L, L).
+
+        Returns:
+            Tensor of shape (B, L, L, D) after attention projection and dropout.
+        """
+        B, L, _, D = x.shape
+        # [B, L, L, 5*D] -> 5 x [B, H, L, L, d]
+        qkv = torch.chunk(self.c_qkv(x), 5, dim=-1)
+        q_ik, k_ij, k_jk, v_ij, v_jk = map(
+            lambda t: t.view(B, L, L, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4),
+            qkv,
+        )
+
+        # [B, H, L, L, d]
+        y = pivotal_attention(
+            q_ik, k_ij, k_jk, v_ij, v_jk,
+            attn_mask=attn_mask,
+            dropout=self.dropout if self.training else 0.0,
+        )
+        y = y.permute(0, 2, 3, 1, 4).contiguous().view(B, L, L, D)
+        y = self.c_proj(y)
+        y = self.dropout_fn(y)
+        return y
+
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if self.enable_symmetric_mix:
+            xT = self.c_mix(x.transpose(1, 2))
+        else:
+            xT = 0
+        if self.norm_position == "pre":
+            x = x + self.attn(self.norm1(x + xT), attn_mask)
+            if self.enable_ffn:
+                x = x + self.ffn(self.norm2(x)) * self.ffn_scale
+        else:
+            x = self.norm1(x + self.attn(x + xT, attn_mask))
+            if self.enable_ffn:
+                x = self.norm2(x + self.ffn(x)) * self.ffn_scale
+
+        return x
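Likewise, a minimal sketch of how the new `PivotalAttentionBlock` might be instantiated; the hyperparameters (4 heads, embedding dim 64, dropout 0.1) are arbitrary, and the import path assumes the wheel exposes `floydnet.transformer` as packaged above.

```python
# Hypothetical example: one PivotalAttentionBlock over pairwise features.
import torch

from floydnet.transformer import PivotalAttentionBlock  # module path as shipped in 0.1.2

B, L, D = 2, 8, 64  # batch, grid side length, embedding dim (arbitrary sizes)

block = PivotalAttentionBlock(
    embed_dim=D,
    num_heads=4,
    dropout=0.1,
    norm_position="pre",  # default: residual branches around attention and FFN
)

x = torch.randn(B, L, L, D)  # pairwise features on an (L x L) grid
y = block(x)                 # attn_mask is optional; omitted here
assert y.shape == (B, L, L, D)
```

Because the block preserves the (B, L, L, D) shape, several blocks can be stacked directly; with the default `enable_symmetric_mix=True`, the transposed (j, i) features are also mixed in through the zero-initialized `c_mix` projection.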
Files without changes:

- {floydnet-0.1.1 → floydnet-0.1.2}/.gitignore
- {floydnet-0.1.1 → floydnet-0.1.2}/CHANGELOG.md
- {floydnet-0.1.1 → floydnet-0.1.2}/CITATION.cff
- {floydnet-0.1.1 → floydnet-0.1.2}/LICENSE