floydnet 0.1.1__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: floydnet
- Version: 0.1.1
+ Version: 1.0.0
  Summary: Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs
  Project-URL: Homepage, https://github.com/ocx-lab/FloydNet
  Project-URL: Repository, https://github.com/ocx-lab/FloydNet
@@ -231,21 +231,23 @@ Requires-Dist: ruff>=0.5; extra == 'dev'
  Description-Content-Type: text/markdown

  # FloydNet
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+ [![Python](https://img.shields.io/badge/Python-3.9%2B-blue)](https://www.python.org/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.1%2B-orange)](https://pytorch.org/)

- Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).
+ Official implementation of [FloydNet](https://openreview.net/pdf?id=aUsx1G6RVQ).

  ![Figure Pivotal Attention Mechanism for 2-Floyd/3-Floyd.](misc/pivotalattn2&3.png)

  This repository serves two audiences:
-
- - **Engineering users**: reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
- - **Research users**: scripts/configs to reproduce paper experiments under `example/`.
+ - **Engineering users**: Reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
+ - **Research users**: Scripts/configs to reproduce paper experiments (TSP, Graph Isomorphism, BREC) under `example/`.

  ---

  ## Introduction

- FloydNet is the official PyTorch implementation accompanying an ICLR paper (TODO).
+ FloydNet is the official PyTorch implementation.
  The repository provides:

  1. **Reusable components**: a drop-in attention/Transformer-block interface intended for integration into existing projects.
@@ -267,10 +269,6 @@ For algorithmic details, hyperparameter choices, and analysis, please refer to t

  ---

- ## Using the Attention / Transformer Block
-
- This section targets **engineering users** who want to import FloydNet as a dependency.
-
  ### Installation

  #### Option A: Install from PyPI
@@ -376,7 +374,7 @@ If you use this code in your research, please cite the paper:
  @inproceedings{TODO,
  title = {TODO},
  author = {TODO},
- booktitle = {International Conference on Learning Representations (ICLR)},
+ booktitle = {TODO},
  year = {TODO},
  url = {TODO}
  }
@@ -1,19 +1,21 @@
  # FloydNet
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+ [![Python](https://img.shields.io/badge/Python-3.9%2B-blue)](https://www.python.org/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.1%2B-orange)](https://pytorch.org/)

- Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).
+ Official implementation of [FloydNet](https://openreview.net/pdf?id=aUsx1G6RVQ).

  ![Figure Pivotal Attention Mechanism for 2-Floyd/3-Floyd.](misc/pivotalattn2&3.png)

  This repository serves two audiences:
-
- - **Engineering users**: reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
- - **Research users**: scripts/configs to reproduce paper experiments under `example/`.
+ - **Engineering users**: Reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
+ - **Research users**: Scripts/configs to reproduce paper experiments (TSP, Graph Isomorphism, BREC) under `example/`.

  ---

  ## Introduction

- FloydNet is the official PyTorch implementation accompanying an ICLR paper (TODO).
+ FloydNet is the official PyTorch implementation.
  The repository provides:

  1. **Reusable components**: a drop-in attention/Transformer-block interface intended for integration into existing projects.
@@ -35,10 +37,6 @@ For algorithmic details, hyperparameter choices, and analysis, please refer to t

  ---

- ## Using the Attention / Transformer Block
-
- This section targets **engineering users** who want to import FloydNet as a dependency.
-
  ### Installation

  #### Option A: Install from PyPI
@@ -144,7 +142,7 @@ If you use this code in your research, please cite the paper:
  @inproceedings{TODO,
  title = {TODO},
  author = {TODO},
- booktitle = {International Conference on Learning Representations (ICLR)},
+ booktitle = {TODO},
  year = {TODO},
  url = {TODO}
  }
@@ -6,6 +6,14 @@ The paper reports results on **three benchmarks**:
  - BREC
  - TSP

+ ## 🚀 Key Results
+
+ | Domain | Benchmark | Result |
+ | :--- | :--- | :--- |
+ | **Algorithmic** | CLRS-30 | **96.64%** (SOTA), effectively solving graph & string algorithms. |
+ | **Optimization** | Non-Metric TSP | **88.3%** optimality on unseen sizes ($N=200$), vs 1.3% for Linkern heuristic. |
+ | **Expressiveness** | Substructure Count | Near-zero error (MAE **0.001**) on complex substructure counting. |
+

  ### Graph Count

  The Graph Count benchmark and dataset construction follow:
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "floydnet"
- version = "0.1.1"
+ version = "1.0.0"
  description = "Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs"
  readme = "README.md"
  requires-python = ">=3.9"
@@ -50,7 +50,7 @@ include = [
  "LICENSE",
  "CITATION.cff",
  "CHANGELOG.md",
- "paper/**",
+ "src/**",
  ]

  [tool.hatch.build.targets.wheel]
@@ -0,0 +1,8 @@
+ from .functional import pivotal_attention, pivotal_attention3
+ from .transformer import PivotalAttentionBlock
+
+ __all__ = [
+     "pivotal_attention",
+     "pivotal_attention3",
+     "PivotalAttentionBlock",
+ ]
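
The hunk above adds what appears to be the package `__init__`, re-exporting the two functional entry points and the Transformer-style block defined in the following hunks. As a quick orientation aid (not part of the released files), the sketch below shows the resulting top-level API, assuming floydnet 1.0.0 is installed from PyPI:

```python
# Sketch only: the public API re-exported by the new package __init__ shown above.
import floydnet

print(floydnet.__all__)
# ['pivotal_attention', 'pivotal_attention3', 'PivotalAttentionBlock']

from floydnet import pivotal_attention, pivotal_attention3, PivotalAttentionBlock
```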
@@ -0,0 +1,150 @@
+ # Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from __future__ import annotations
+
+ from typing import Optional
+ import math
+
+ import torch
+ import torch.nn.functional as F
+
+ def pivotal_attention(
+     q_ik: torch.Tensor,
+     k_ij: torch.Tensor,
+     k_jk: torch.Tensor,
+     v_ij: torch.Tensor,
+     v_jk: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout: float = 0.0,
+     scale: Optional[float] = None,
+     inf: float = 1e9,
+ ) -> torch.Tensor:
+     """Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
+
+     Shapes:
+         q_ik: (B, H, L_i, L_k, D)
+         k_ij: (B, H, L_i, L_j, D)
+         k_jk: (B, H, L_j, L_k, D)
+         v_ij: (B, H, L_i, L_j, D)
+         v_jk: (B, H, L_j, L_k, D)
+         attn_mask (optional): broadcastable to (B, H, L_i, L_k, L_j)
+
+     Args:
+         attn_mask: Additive mask (float) or boolean mask. If boolean, masked positions are set to -inf.
+         dropout: Dropout probability applied to attention weights (only effective if > 0).
+         scale: Optional custom scaling factor. If None, defaults to 1/sqrt(2*D).
+         inf: Value to use for -infinity in masks.
+
+     Returns:
+         Tensor of shape (B, H, L_i, L_k, D)
+     """
+     assert all([t.dim() == 5 for t in [q_ik, k_ij, k_jk, v_ij, v_jk]]), "All inputs must be 5D tensors"
+     B, H, L_i, L_k, D = q_ik.shape
+     L_j = k_ij.shape[3]
+     assert k_ij.shape == v_ij.shape == (B, H, L_i, L_j, D), "k_ij and v_ij must have shape (B, H, L_i, L_j, D)"
+     assert k_jk.shape == v_jk.shape == (B, H, L_j, L_k, D), "k_jk and v_jk must have shape (B, H, L_j, L_k, D)"
+
+     if scale is None:
+         scale = 1.0 / math.sqrt(2.0 * D)
+     q_ik = q_ik * scale
+
+     # Compute attention scores over the pivot dimension j: (B, H, L_i, L_k, L_j)
+     attn_scores = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij) \
+         + torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)
+
+     if attn_mask is not None:
+         if attn_mask.dtype == torch.bool:
+             attn_scores = attn_scores.masked_fill(attn_mask, -inf)
+         else:
+             attn_scores = attn_scores + attn_mask
+
+     attn_weights = torch.softmax(attn_scores, dim=-1)
+
+     if dropout > 0.0:
+         attn_weights = F.dropout(attn_weights, p=dropout)
+
+     y = torch.einsum("bhikj,bhijd->bhikd", attn_weights, v_ij) \
+         + torch.einsum("bhikj,bhjkd->bhikd", attn_weights, v_jk)
+
+     return y
+
+ def pivotal_attention3(
+     q_ijk: torch.Tensor,
+     k_pjk: torch.Tensor,
+     k_ipk: torch.Tensor,
+     k_ijp: torch.Tensor,
+     v_pjk: torch.Tensor,
+     v_ipk: torch.Tensor,
+     v_ijp: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout: float = 0.0,
+     scale: Optional[float] = None,
+     inf: float = 1e9,
+ ) -> torch.Tensor:
+     """3-Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
+
+     Shapes:
+         q_ijk: (B, H, L_i, L_j, L_k, D)
+         k_pjk: (B, H, L_p, L_j, L_k, D)
+         k_ipk: (B, H, L_i, L_p, L_k, D)
+         k_ijp: (B, H, L_i, L_j, L_p, D)
+         v_pjk: (B, H, L_p, L_j, L_k, D)
+         v_ipk: (B, H, L_i, L_p, L_k, D)
+         v_ijp: (B, H, L_i, L_j, L_p, D)
+         attn_mask (optional): broadcastable to (B, H, L_i, L_j, L_k, L_p)
+
+     Args:
+         attn_mask: Additive mask (float) or boolean mask. If boolean, masked positions are set to -inf.
+         dropout: Dropout probability applied to attention weights (only effective if > 0).
+         scale: Optional custom scaling factor. If None, defaults to 1/sqrt(3*D).
+         inf: Value to use for -infinity in masks.
+
+     Returns:
+         Tensor of shape (B, H, L_i, L_j, L_k, D)
+     """
+     assert all([t.dim() == 6 for t in [q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp]]), "All inputs must be 6D tensors"
+     B, H, L_i, L_j, L_k, D = q_ijk.shape
+     L_p = k_pjk.shape[2]
+     assert k_pjk.shape == v_pjk.shape == (B, H, L_p, L_j, L_k, D), "k_pjk and v_pjk must have shape (B, H, L_p, L_j, L_k, D)"
+     assert k_ipk.shape == v_ipk.shape == (B, H, L_i, L_p, L_k, D), "k_ipk and v_ipk must have shape (B, H, L_i, L_p, L_k, D)"
+     assert k_ijp.shape == v_ijp.shape == (B, H, L_i, L_j, L_p, D), "k_ijp and v_ijp must have shape (B, H, L_i, L_j, L_p, D)"
+
+     if scale is None:
+         scale = 1.0 / math.sqrt(3.0 * D)
+     q_ijk = q_ijk * scale
+
+     # Compute attention scores over the pivot dimension p: (B, H, L_i, L_j, L_k, L_p)
+     attn_scores = torch.einsum("bhijkd,bhpjkd->bhijkp", q_ijk, k_pjk) \
+         + torch.einsum("bhijkd,bhipkd->bhijkp", q_ijk, k_ipk) \
+         + torch.einsum("bhijkd,bhijpd->bhijkp", q_ijk, k_ijp)
+
+     if attn_mask is not None:
+         if attn_mask.dtype == torch.bool:
+             attn_scores = attn_scores.masked_fill(attn_mask, -inf)
+         else:
+             attn_scores = attn_scores + attn_mask
+
+     attn_weights = torch.softmax(attn_scores, dim=-1)
+
+     if dropout > 0.0:
+         attn_weights = F.dropout(attn_weights, p=dropout)
+
+     y = torch.einsum("bhijkp,bhpjkd->bhijkd", attn_weights, v_pjk) \
+         + torch.einsum("bhijkp,bhipkd->bhijkd", attn_weights, v_ipk) \
+         + torch.einsum("bhijkp,bhijpd->bhijkd", attn_weights, v_ijp)
+
+     return y
+
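
To make the shapes in the docstrings above concrete, here is a small self-contained sketch (not part of the diff) that calls the two functional APIs with random tensors; the batch, head, length, and feature sizes are arbitrary, and a square grid (L_i = L_j = L_k = L_p = L) is assumed for simplicity:

```python
import torch
from floydnet import pivotal_attention, pivotal_attention3

B, H, L, D = 2, 4, 6, 8  # arbitrary small sizes

# 2-Floyd pivotal attention: on a square grid all five inputs are (B, H, L, L, D).
q_ik, k_ij, k_jk, v_ij, v_jk = (torch.randn(B, H, L, L, D) for _ in range(5))
out2 = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk)
assert out2.shape == (B, H, L, L, D)

# Optional boolean mask, broadcastable to (B, H, L_i, L_k, L_j); True masks out a pivot j.
mask = torch.zeros(B, 1, L, L, L, dtype=torch.bool)
out2_masked = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk, attn_mask=mask)

# 3-Floyd pivotal attention: on a cubic grid all seven inputs are (B, H, L, L, L, D).
q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp = (
    torch.randn(B, H, L, L, L, D) for _ in range(7)
)
out3 = pivotal_attention3(q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp)
assert out3.shape == (B, H, L, L, L, D)
```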
@@ -0,0 +1,219 @@
+ # Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import copy
+ from typing import Optional, Tuple, Union, Callable
+
+ import torch
+ from torch import nn
+ from .functional import pivotal_attention
+
+
+ class Affine(nn.Module):
+     def __init__(self, c):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones((c, )))
+         self.bias = nn.Parameter(torch.zeros((c, )))
+
+     def forward(self, x: torch.Tensor):
+         return x * self.weight + self.bias
+
+
+ def create_norm(norm_fn: Union[str, Callable], embed_dim: int, eps: float = 1e-5, **kwargs) -> nn.Module:
+     """Create a normalization module from a name or nn.Module.
+
+     Args:
+         norm_fn: Name or an nn.Module instance/class.
+         embed_dim: Embedding dimension (features) used to construct the norm.
+         eps: Numerical epsilon passed to the normalization layer if applicable.
+         **kwargs: Extra keyword arguments forwarded to the normalization layer.
+
+     Returns:
+         An nn.Module normalization instance.
+     """
+     if isinstance(norm_fn, str):
+         if norm_fn.lower() in ["layernorm", "ln"]:
+             return nn.LayerNorm(embed_dim, eps=eps, **kwargs)
+         elif norm_fn.lower() in ["batchnorm", "bn"]:
+             return nn.BatchNorm1d(embed_dim, eps=eps, **kwargs)
+         elif norm_fn.lower() in ["rmsnorm", "rms"]:
+             return nn.RMSNorm(embed_dim, eps=eps, **kwargs)
+         elif norm_fn.lower() in ["affine"]:
+             return Affine(embed_dim)
+         elif norm_fn.lower() in ["none", "identity"]:
+             return nn.Identity()
+         else:
+             raise ValueError(f"Unsupported norm_fn string: {norm_fn}")
+     elif callable(norm_fn):
+         if isinstance(norm_fn, nn.Module):
+             # deepcopy to avoid shared parameters
+             return copy.deepcopy(norm_fn)
+         elif isinstance(norm_fn, type) and issubclass(norm_fn, nn.Module):
+             return norm_fn(embed_dim, eps=eps, **kwargs)
+         else:
+             raise TypeError("norm_fn callable must be an nn.Module or nn.Module class")
+     else:
+         raise TypeError("norm_fn must be a string or callable")
+
+
+ def create_activation(activation_fn: Union[str, Callable]) -> nn.Module:
+     """Create an activation module from a name or nn.Module.
+
+     Args:
+         activation_fn: Name or an nn.Module instance/class.
+
+     Returns:
+         An nn.Module activation instance.
+     """
+     if isinstance(activation_fn, str):
+         if activation_fn.lower() == "relu":
+             return nn.ReLU()
+         elif activation_fn.lower() == "gelu":
+             return nn.GELU()
+         elif activation_fn.lower() == "silu":
+             return nn.SiLU()
+         else:
+             raise ValueError(f"Unsupported activation_fn string: {activation_fn}")
+     elif callable(activation_fn):
+         if isinstance(activation_fn, nn.Module):
+             return activation_fn
+         elif isinstance(activation_fn, type) and issubclass(activation_fn, nn.Module):
+             return activation_fn()
+         else:
+             raise TypeError("activation_fn callable must be an nn.Module or nn.Module class")
+     else:
+         raise TypeError("activation_fn must be a string or callable")
+
+
+ class PivotalAttentionBlock(nn.Module):
+     """Transformer-style block that applies pivotal attention followed by an FFN.
+
+     Args:
+         embed_dim: Input/hidden embedding dimension (D).
+         num_heads: Number of attention heads (D must be divisible by num_heads).
+         dropout: Dropout probability for attention output and FFN output.
+         bias: Whether to include bias terms in linear layers.
+         ffn_expansion_ratio: Expansion ratio for the FFN hidden size.
+         norm_position: "pre" or "post" layer normalization placement.
+         activation_fn: Activation name/module used in the FFN.
+         norm_fn: Normalization name/module used in the block.
+         enable_symmetric_mix: Whether to mix the transposed (j, i) features into the input before attention.
+         enable_ffn: Whether to include the feed-forward sub-layer.
+     """
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         dropout: float = 0.0,
+         bias: bool = False,
+         ffn_expansion_ratio: int = 4,
+         norm_position: str = "pre",
+         activation_fn: Union[str, Callable] = "relu",
+         norm_fn: Union[str, Callable] = "layernorm",
+         enable_symmetric_mix: bool = True,
+         enable_ffn: bool = True,
+     ) -> None:
+         super().__init__()
+         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.dropout = dropout
+         self.norm_position = norm_position.lower()
+         self.enable_ffn = enable_ffn
+         assert self.norm_position in ["pre", "post"], "norm_position must be 'pre' or 'post'"
+
+         self.enable_symmetric_mix = enable_symmetric_mix
+         if enable_symmetric_mix:
+             self.c_mix = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         self.c_qkv = nn.Linear(embed_dim, embed_dim * 5, bias=bias)
+         self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.dropout_fn = nn.Dropout(dropout)
+         self.norm1 = create_norm(norm_fn, embed_dim)
+         if self.enable_ffn:
+             self.activation_fn = create_activation(activation_fn)
+             self.norm2 = create_norm(norm_fn, embed_dim)
+             self.ffn = nn.Sequential(
+                 nn.Linear(embed_dim, ffn_expansion_ratio * embed_dim, bias=bias),
+                 self.activation_fn,
+                 nn.Linear(ffn_expansion_ratio * embed_dim, embed_dim, bias=bias),
+                 nn.Dropout(dropout),
+             )
+             self.ffn_scale = nn.Parameter(torch.tensor(1.0, requires_grad=True))
+
+         self._reset_parameters()
+
+     def _reset_parameters(self) -> None:
+         """Initialize parameters using Xavier for projections and zeros for output heads."""
+         if self.enable_symmetric_mix:
+             nn.init.zeros_(self.c_mix.weight)
+         nn.init.xavier_uniform_(self.c_qkv.weight)
+         nn.init.zeros_(self.c_proj.weight)
+         if self.enable_ffn:
+             nn.init.xavier_uniform_(self.ffn[0].weight)
+             nn.init.zeros_(self.ffn[2].weight)
+         if self.c_qkv.bias is not None:
+             if self.enable_symmetric_mix:
+                 nn.init.zeros_(self.c_mix.bias)
+             nn.init.zeros_(self.c_qkv.bias)
+             nn.init.zeros_(self.c_proj.bias)
+             if self.enable_ffn:
+                 nn.init.zeros_(self.ffn[0].bias)
+                 nn.init.zeros_(self.ffn[2].bias)
+
+     def attn(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]) -> torch.Tensor:
+         """Apply pivotal attention over a (L x L) grid.
+
+         Args:
+             x: Input tensor of shape (B, L, L, D).
+             attn_mask: Optional mask broadcastable to (B, H, L, L, L).
+
+         Returns:
+             Tensor of shape (B, L, L, D) after attention projection and dropout.
+         """
+         B, L, _, D = x.shape
+         # [B, L, L, 5*D] -> 5 x [B, H, L, L, d]
+         qkv = torch.chunk(self.c_qkv(x), 5, dim=-1)
+         q_ik, k_ij, k_jk, v_ij, v_jk = map(
+             lambda t: t.view(B, L, L, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4),
+             qkv,
+         )
+
+         # [B, H, L, L, d]
+         y = pivotal_attention(
+             q_ik, k_ij, k_jk, v_ij, v_jk,
+             attn_mask=attn_mask,
+             dropout=self.dropout if self.training else 0.0,
+         )
+         y = y.permute(0, 2, 3, 1, 4).contiguous().view(B, L, L, D)
+         y = self.c_proj(y)
+         y = self.dropout_fn(y)
+         return y
+
+     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if self.enable_symmetric_mix:
+             xT = self.c_mix(x.transpose(1, 2))
+         else:
+             xT = 0
+         if self.norm_position == "pre":
+             x = x + self.attn(self.norm1(x + xT), attn_mask)
+             if self.enable_ffn:
+                 x = x + self.ffn(self.norm2(x)) * self.ffn_scale
+         else:
+             x = self.norm1(x + self.attn(x + xT, attn_mask))
+             if self.enable_ffn:
+                 x = self.norm2(x + self.ffn(x)) * self.ffn_scale
+
+         return x
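
Similarly, the block added above behaves like an ordinary Transformer layer over pairwise features. The following usage sketch (again not part of the package; sizes are arbitrary) assumes a square (L x L) grid of relation embeddings, as documented in `attn`:

```python
import torch
from floydnet import PivotalAttentionBlock

B, L, D, H = 2, 10, 64, 4  # batch, grid side, embedding dim, heads (arbitrary)

block = PivotalAttentionBlock(embed_dim=D, num_heads=H, dropout=0.1, norm_position="pre")
x = torch.randn(B, L, L, D)  # pairwise relation features on an (L x L) grid

y = block(x)                 # residual block keeps the input shape
assert y.shape == (B, L, L, D)

# Optional boolean mask, broadcastable to (B, H, L, L, L); True masks out a pivot position.
attn_mask = torch.zeros(B, 1, L, L, L, dtype=torch.bool)
y_masked = block(x, attn_mask=attn_mask)

# Blocks stack like standard Transformer layers.
stack = torch.nn.Sequential(*[PivotalAttentionBlock(D, H) for _ in range(4)])
y_deep = stack(x)
```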
4 files without changes