floydnet 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -205,3 +205,14 @@ cython_debug/
205
205
  marimo/_static/
206
206
  marimo/_lsp/
207
207
  __marimo__/
208
+
209
+ example/data/count/processed/
210
+ example/data/count/hom.npy
211
+ example/data/count/iso.npy
212
+ example/wandb
213
+ example/output
214
+ example/outputs
215
+ .github/
216
+ count.out
217
+ run_scripts
218
+ example/data/TSP
@@ -0,0 +1,12 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ ## [1.0.0] - 2026-01-25
6
+ - Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
7
+ - Added `pivotal_attention3` functional API for 3-Floyd attention.
8
+ - Added more configuration options to `PivotalAttentionBlock`.
9
+
10
+ ## [0.1.0] - 2025-10-21
11
+ - Initial public skeleton with module + functional attention
12
+ - Examples scaffolding
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: floydnet
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs
5
5
  Project-URL: Homepage, https://github.com/ocx-lab/FloydNet
6
6
  Project-URL: Repository, https://github.com/ocx-lab/FloydNet
@@ -230,49 +230,160 @@ Requires-Dist: pytest>=7.4; extra == 'dev'
230
230
  Requires-Dist: ruff>=0.5; extra == 'dev'
231
231
  Description-Content-Type: text/markdown
232
232
 
233
- # floyd-net
233
+ # FloydNet
234
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
235
+ [![Python](https://img.shields.io/badge/Python-3.9%2B-blue)](https://www.python.org/)
236
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.1%2B-orange)](https://pytorch.org/)
234
237
 
235
- Floyd Multi-Head Attention (F-MHA) is a drop-in variant of PyTorch's attention stack. It provides:
238
+ Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).
236
239
 
237
- - Module API: `FloydMultiheadAttention` mirroring `torch.nn.MultiheadAttention`
238
- - Functional API: `floyd_scaled_dot_product_attention` mirroring `torch.nn.functional.scaled_dot_product_attention`
240
+ ![Figure: Pivotal Attention mechanism for 2-Floyd/3-Floyd.](misc/pivotalattn2&3.png)
239
241
 
240
- Install and manage with `uv` for a modern Python workflow.
242
+ This repository serves two audiences:
243
+ - **Engineering users**: Reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
244
+ - **Research users**: Scripts/configs to reproduce paper experiments (TSP, Graph Isomorphism, BREC) under `example/`.
241
245
 
242
- ## Quick start
246
+ ---
243
247
 
248
+ ## Introduction
249
+
250
+ FloydNet is the official PyTorch implementation accompanying an ICLR paper (TODO).
251
+ The repository provides:
252
+
253
+ 1. **Reusable components**: a drop-in attention/Transformer-block interface intended for integration into existing projects.
254
+ 2. **Reproduction code**: end-to-end training/evaluation pipelines to reproduce the benchmarks reported in the paper.
255
+
256
+ For algorithmic details, hyperparameter choices, and analysis, please refer to the paper (TODO: link).
257
+
258
+ ---
259
+
260
+ ## Repository Structure
261
+
262
+ - `src/floydnet/`
263
+ **Library code for reuse**
264
+ Contains the functional attention API and module/block implementations.
265
+
266
+ - `example/`
267
+ **Experiment reproduction code**
268
+ Includes benchmark-specific scripts, configs, and data preparation utilities.
269
+
270
+ ---
271
+
272
+ ### Installation
273
+
274
+ #### Option A: Install from PyPI
244
275
  ```bash
245
- # Install with uv (recommended)
246
- uv venv --python 3.10
247
- source .venv/bin/activate
248
- uv pip install -e .[dev]
276
+ pip install floydnet
249
277
  ```
250
278
 
251
- ```python
252
- import torch
253
- from floyd_net import FloydMultiheadAttention
279
+ #### Option B: Install from source
280
+ ```bash
281
+ git clone git@github.com:ocx-lab/FloydNet.git
282
+ cd FloydNet
283
+ pip install -e .
284
+ ```
254
285
 
255
- m = FloydMultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
256
- x = torch.randn(2, 16, 64)
257
- out, attn = m(x, x, x)
258
- print(out.shape)
286
+ > Requirements: Python `>= 3.9`, PyTorch `>= 2.1` (see `pyproject.toml`).
287
+
288
+ ### Public API
289
+
290
+ FloydNet re-exports the public API from `src/floydnet/__init__.py`, so you can import from the top-level package:
291
+
292
+ - **Functional API**:
293
+   - `pivotal_attention` (see `src/floydnet/functional.py`)
+   - `pivotal_attention3` (see `src/floydnet/functional.py`)
294
+ - **Module / block API**:
295
+ - `PivotalAttentionBlock` (see `src/floydnet/transformer.py`)
296
+
297
+ ```python
298
+ from floydnet import pivotal_attention, PivotalAttentionBlock
259
299
  ```
260
300
 
261
- ### Functional API
301
+ ### Minimal usage example
302
+
262
303
  ```python
263
304
  import torch
264
- import torch.nn.functional as F
265
- from floyd_net import floyd_scaled_dot_product_attention
305
+ from floydnet import pivotal_attention, PivotalAttentionBlock
306
+
307
+ # -------------------------
308
+ # Module API (Transformer-style block)
309
+ # Input is a 2D grid: (B, N, N, C)
310
+ # -------------------------
311
+ B, N, C = 2, 16, 64
312
+ x = torch.randn(B, N, N, C)
266
313
 
267
- q = torch.randn(2, 8, 16, 64) # (B, H, L, D)
268
- k = torch.randn(2, 8, 16, 64)
269
- v = torch.randn(2, 8, 16, 64)
270
- out = floyd_scaled_dot_product_attention(q, k, v)
314
+ m = PivotalAttentionBlock(embed_dim=C, num_heads=8, dropout=0.0)
315
+ out = m(x) # (B, N, N, C)
271
316
  print(out.shape)
317
+
318
+ # -------------------------
319
+ # Functional API
320
+ # All inputs are 5D: (B, H, N, N, D)
321
+ # -------------------------
322
+ B, H, N, D = 2, 8, 16, 64
323
+ q_ik = torch.randn(B, H, N, N, D)
324
+ k_ij = torch.randn(B, H, N, N, D)
325
+ k_jk = torch.randn(B, H, N, N, D)
326
+ v_ij = torch.randn(B, H, N, N, D)
327
+ v_jk = torch.randn(B, H, N, N, D)
328
+
329
+ y = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk) # (B, H, N, N, D)
330
+ print(y.shape)
331
+ ```
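+
+ The package also exports `pivotal_attention3` for 3-Floyd attention (see `src/floydnet/functional.py`). A minimal sketch, using small sizes because every input is a 6D `(B, H, N, N, N, D)` tensor:
+
+ ```python
+ import torch
+ from floydnet import pivotal_attention3
+
+ # Keep sizes small: attention runs over an extra pivot axis p.
+ B, H, N, D = 1, 2, 6, 8
+ q_ijk = torch.randn(B, H, N, N, N, D)
+ k_pjk, k_ipk, k_ijp = (torch.randn(B, H, N, N, N, D) for _ in range(3))
+ v_pjk, v_ipk, v_ijp = (torch.randn(B, H, N, N, N, D) for _ in range(3))
+
+ y = pivotal_attention3(q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp)  # (B, H, N, N, N, D)
+ print(y.shape)
+ ```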
332
+
333
+ ---
334
+
335
+ ## Reproducing Paper Results
336
+
337
+ This section targets **research users** who want to reproduce the experiments in the paper.
338
+
339
+ See `example/README.md` for a detailed description.
340
+
341
+ ### Environment setup
342
+
343
+ We recommend using `uv` to create an isolated environment for the reproduction code under `example/`.
344
+
345
+ ```bash
346
+ cd /path/to/FloydNet
347
+
348
+ # 1) Create a uv virtual environment with Python 3.12
349
+ uv venv --python 3.12
350
+
351
+ # 2) Activate it
352
+ source .venv/bin/activate
353
+
354
+ # 3) Install extra dependencies for reproducing paper experiments
355
+ uv pip install -r example/requirements.txt
356
+
357
+ # 4) Install FloydNet (editable) for local development / imports
358
+ uv pip install -e .
359
+ ```
360
+
361
+ ## Changelog (latest)
362
+
363
+ - Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
364
+ - Added `pivotal_attention3` functional API for 3-Floyd attention.
365
+ - Added more configuration options to `PivotalAttentionBlock` (sketched below).
366
+
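+ A minimal sketch of those configuration options (argument names follow `src/floydnet/transformer.py`; the values below are illustrative, not recommended defaults):
+
+ ```python
+ import torch
+ from floydnet import PivotalAttentionBlock
+
+ B, N, C = 2, 12, 64
+ x = torch.randn(B, N, N, C)
+
+ block = PivotalAttentionBlock(
+     embed_dim=C,
+     num_heads=8,
+     dropout=0.1,
+     bias=True,                   # add bias terms to the linear layers (default: False)
+     ffn_expansion_ratio=2,       # FFN hidden-size multiplier (default: 4)
+     norm_position="post",        # "pre" (default) or "post" normalization
+     activation_fn="gelu",        # FFN activation: "relu" (default), "gelu", or "silu"
+     norm_fn="layernorm",         # also accepts "rmsnorm", "affine", "identity", or an nn.Module
+     enable_symmetric_mix=False,  # disable the transpose-mixing branch
+ )
+ out = block(x)  # (B, N, N, C)
+ print(out.shape)
+ ```
+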
367
+ The full changelog is in [CHANGELOG.md](CHANGELOG.md).
368
+
369
+ ## Citation
370
+
371
+ If you use this code in your research, please cite the paper:
372
+
373
+ ```bibtex
374
+ @inproceedings{TODO,
375
+ title = {TODO},
376
+ author = {TODO},
377
+ booktitle = {International Conference on Learning Representations (ICLR)},
378
+ year = {TODO},
379
+ url = {TODO}
380
+ }
272
381
  ```
273
382
 
274
- ## Paper reproductions
275
- See `paper/` for dataset preparation, configs, and experiment templates to reproduce the results in the paper.
383
+ (Alternatively, see [CITATION.cff](CITATION.cff).)
384
+
385
+ ---
276
386
 
277
387
  ## License
278
- MIT
388
+
389
+ This project is licensed under the **Apache License 2.0**. See [LICENSE](LICENSE).
@@ -0,0 +1,157 @@
1
+ # FloydNet
2
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
3
+ [![Python](https://img.shields.io/badge/Python-3.9%2B-blue)](https://www.python.org/)
4
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.1%2B-orange)](https://pytorch.org/)
5
+
6
+ Official implementation of an ICLR paper (TODO: add paper title, authors, and links/arXiv).
7
+
8
+ ![Figure: Pivotal Attention mechanism for 2-Floyd/3-Floyd.](misc/pivotalattn2&3.png)
9
+
10
+ This repository serves two audiences:
11
+ - **Engineering users**: Reusable PyTorch components (functional attention APIs and Transformer-style blocks) under `src/`.
12
+ - **Research users**: Scripts/configs to reproduce paper experiments (TSP, Graph Isomorphism, BREC) under `example/`.
13
+
14
+ ---
15
+
16
+ ## Introduction
17
+
18
+ FloydNet is the official PyTorch implementation accompanying an ICLR paper (TODO).
19
+ The repository provides:
20
+
21
+ 1. **Reusable components**: a drop-in attention/Transformer-block interface intended for integration into existing projects.
22
+ 2. **Reproduction code**: end-to-end training/evaluation pipelines to reproduce the benchmarks reported in the paper.
23
+
24
+ For algorithmic details, hyperparameter choices, and analysis, please refer to the paper (TODO: link).
25
+
26
+ ---
27
+
28
+ ## Repository Structure
29
+
30
+ - `src/floydnet/`
31
+ **Library code for reuse**
32
+ Contains the functional attention API and module/block implementations.
33
+
34
+ - `example/`
35
+ **Experiment reproduction code**
36
+ Includes benchmark-specific scripts, configs, and data preparation utilities.
37
+
38
+ ---
39
+
40
+ ### Installation
41
+
42
+ #### Option A: Install from PyPI
43
+ ```bash
44
+ pip install floydnet
45
+ ```
46
+
47
+ #### Option B: Install from source
48
+ ```bash
49
+ git clone git@github.com:ocx-lab/FloydNet.git
50
+ cd FloydNet
51
+ pip install -e .
52
+ ```
53
+
54
+ > Requirements: Python `>= 3.9`, PyTorch `>= 2.1` (see `pyproject.toml`).
55
+
56
+ ### Public API
57
+
58
+ FloydNet re-exports the public API from `src/floydnet/__init__.py`, so you can import from the top-level package:
59
+
60
+ - **Functional API**:
61
+   - `pivotal_attention` (see `src/floydnet/functional.py`)
+   - `pivotal_attention3` (see `src/floydnet/functional.py`)
62
+ - **Module / block API**:
63
+ - `PivotalAttentionBlock` (see `src/floydnet/transformer.py`)
64
+
65
+ ```python
66
+ from floydnet import pivotal_attention, PivotalAttentionBlock
67
+ ```
68
+
69
+ ### Minimal usage example
70
+
71
+ ```python
72
+ import torch
73
+ from floydnet import pivotal_attention, PivotalAttentionBlock
74
+
75
+ # -------------------------
76
+ # Module API (Transformer-style block)
77
+ # Input is a 2D grid: (B, N, N, C)
78
+ # -------------------------
79
+ B, N, C = 2, 16, 64
80
+ x = torch.randn(B, N, N, C)
81
+
82
+ m = PivotalAttentionBlock(embed_dim=C, num_heads=8, dropout=0.0)
83
+ out = m(x) # (B, N, N, C)
84
+ print(out.shape)
85
+
86
+ # -------------------------
87
+ # Functional API
88
+ # All inputs are 5D: (B, H, N, N, D)
89
+ # -------------------------
90
+ B, H, N, D = 2, 8, 16, 64
91
+ q_ik = torch.randn(B, H, N, N, D)
92
+ k_ij = torch.randn(B, H, N, N, D)
93
+ k_jk = torch.randn(B, H, N, N, D)
94
+ v_ij = torch.randn(B, H, N, N, D)
95
+ v_jk = torch.randn(B, H, N, N, D)
96
+
97
+ y = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk) # (B, H, N, N, D)
98
+ print(y.shape)
99
+ ```
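+
+ Both functional APIs accept an optional `attn_mask` over the pivot dimension; boolean masks are filled with `-inf` before the softmax, and float masks are added to the attention scores. A minimal sketch for the 2-Floyd case, assuming a mask broadcastable to `(B, H, N, N, N)`:
+
+ ```python
+ import torch
+ from floydnet import pivotal_attention
+
+ B, H, N, D = 2, 8, 16, 64
+ q_ik, k_ij, k_jk, v_ij, v_jk = (torch.randn(B, H, N, N, D) for _ in range(5))
+
+ # Boolean mask broadcastable to (B, H, N, N, N); True marks pivot indices j to exclude.
+ attn_mask = torch.zeros(1, 1, 1, 1, N, dtype=torch.bool)
+ attn_mask[..., -2:] = True  # e.g. drop the last two pivots for every (i, k) pair
+
+ y = pivotal_attention(q_ik, k_ij, k_jk, v_ij, v_jk, attn_mask=attn_mask)  # (B, H, N, N, D)
+ print(y.shape)
+ ```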
100
+
101
+ ---
102
+
103
+ ## Reproducing Paper Results
104
+
105
+ This section targets **research users** who want to reproduce the experiments in the paper.
106
+
107
+ See `example/README.md` for a detailed description.
108
+
109
+ ### Environment setup
110
+
111
+ We recommend using `uv` to create an isolated environment for the reproduction code under `example/`.
112
+
113
+ ```bash
114
+ cd /path/to/FloydNet
115
+
116
+ # 1) Create a uv virtual environment with Python 3.12
117
+ uv venv --python 3.12
118
+
119
+ # 2) Activate it
120
+ source .venv/bin/activate
121
+
122
+ # 3) Install extra dependencies for reproducing paper experiments
123
+ uv pip install -r example/requirements.txt
124
+
125
+ # 4) Install FloydNet (editable) for local development / imports
126
+ uv pip install -e .
127
+ ```
128
+
129
+ ## Changelog (latest)
130
+
131
+ - Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
132
+ - Added `pivotal_attention3` functional API for 3-Floyd attention.
133
+ - Added more configuration options to `PivotalAttentionBlock`.
134
+
135
+ The full changelog is in [CHANGELOG.md](CHANGELOG.md).
136
+
137
+ ## Citation
138
+
139
+ If you use this code in your research, please cite the paper:
140
+
141
+ ```bibtex
142
+ @inproceedings{TODO,
143
+ title = {TODO},
144
+ author = {TODO},
145
+ booktitle = {International Conference on Learning Representations (ICLR)},
146
+ year = {TODO},
147
+ url = {TODO}
148
+ }
149
+ ```
150
+
151
+ (Alternatively, see [CITATION.cff](CITATION.cff).)
152
+
153
+ ---
154
+
155
+ ## License
156
+
157
+ This project is licensed under the **Apache License 2.0**. See [LICENSE](LICENSE).
@@ -0,0 +1,137 @@
1
+ ### Benchmarks
2
+
3
+ The paper reports results on **three benchmarks**:
4
+
5
+ - Graph Count
6
+ - BREC
7
+ - TSP
8
+
9
+ ## 🚀 Key Results
10
+
11
+ | Domain | Benchmark | Result |
12
+ | :--- | :--- | :--- |
13
+ | **Algorithmic** | CLRS-30 | **96.64%** (SOTA), effectively solving graph & string algorithms. |
14
+ | **Optimization** | Non-Metric TSP | **88.3%** optimality on unseen sizes ($N=200$), vs 1.3% for Linkern heuristic. |
15
+ | **Expressiveness** | Substructure Count | Near-zero error (MAE **0.001**) on complex substructure counting. |
16
+
17
+ ### Graph Count
18
+
19
+ The Graph Count benchmark and dataset construction follow:
20
+ https://github.com/subgraph23/homomorphism-expressivity
21
+
22
+ ```bash
23
+ source .venv/bin/activate
24
+ cd example
25
+ python -m data.count.process_data
26
+ ./count/run.sh
27
+ ```
28
+
29
+ ### BREC
30
+
31
+ The BREC benchmark and dataset construction follow:
32
+ https://github.com/GraphPKU/BREC
33
+
34
+ ```bash
35
+ source .venv/bin/activate
36
+ cd example/data/BREC/raw
37
+ unzip BREC_data_all.zip
38
+
39
+ # Back to the example folder
40
+ cd ../../..
41
+
42
+ # Reproduce FloydNet results
43
+ python -m BREC.test_BREC
44
+
45
+ # Reproduce 3-Floyd results
46
+ python -m BREC.test_BREC --floyd_level 3
47
+ ```
48
+
49
+ ### TSP
50
+
51
+ Reproducing TSP at full scale is computationally heavy and involves large datasets. For convenience, we provide:
52
+
53
+ - A small demo dataset on Hugging Face:
54
+ https://huggingface.co/datasets/ocxlabs/FloydNet_TSP_demo
55
+ - Pretrained checkpoints for:
56
+ - **Metric TSP** (Euclidean TSP, `euc`): https://huggingface.co/ocxlabs/FloydNet_TSP_euc
57
+ - **Non-metric TSP** (Explicit TSP, `exp`): https://huggingface.co/ocxlabs/FloydNet_TSP_exp
58
+
59
+ This section describes **inference and evaluation** using the demo data and checkpoints.
60
+
61
+ #### Prepare demo data
62
+
63
+ Download the demo dataset as a `.zip`, unzip it, and place the extracted folder under `example/data/`.
64
+
65
+ #### Inference
66
+
67
+ Run inference in `--test_mode` using `torchrun` (the command below assumes **single-node, 8 GPUs**).
68
+ Set `--subset` and make sure `--load_checkpoint` matches the subset.
69
+
70
+ ```bash
71
+ source .venv/bin/activate
72
+ cd example
73
+
74
+ torchrun \
75
+ --nproc_per_node=8 \
76
+ -m TSP.run \
77
+ --subset exp \
78
+ --output_dir ./outputs/TSP_exp \
79
+ --load_checkpoint path/to/TSP_exp/epoch_01000.pt \
80
+ --test_mode \
81
+ --split_factor 1 \
82
+ --sample_count_per_case 10
83
+ ```
84
+
85
+ #### Evaluation
86
+
87
+ After inference finishes, aggregate results with:
88
+
89
+ ```bash
90
+ source .venv/bin/activate
91
+ cd example
92
+
93
+ python -m TSP.report ./outputs/TSP_exp
94
+ ```
95
+
96
+ This saves CSV summaries (downsampled to 1 / 5 / 10 samples per instance) into the same `output_dir`.
97
+
98
+ #### Data generation
99
+
100
+ If you want to generate additional data (beyond the demo set) and train from scratch, prepare the raw `.npy` files as follows.
101
+
102
+ ##### Metric TSP (Euclidean TSP, `euc`)
103
+
104
+ 1. Randomly sample **N integer points** in 2D, $p_i = (x_i, y_i)$, and ensure **pairwise Euclidean distances ≤ 200**.
105
+ 2. Solve the instance with a classic TSP solver (e.g., [Concorde](https://www.math.uwaterloo.ca/tsp/concorde.html)).
106
+ 3. Reorder the points so that $p_0 \rightarrow p_1 \rightarrow ... \rightarrow p_{N-1}$ is the optimal tour.
107
+ 4. Sample **T** instances for each **N**, stack them into a NumPy array of shape **`[T, N, 2]`** with dtype **`int8`**, and save grouped-by-N arrays as:
108
+ - `data/TSP/euc/non-uni/raw/{N:03d}.npy`
109
+
110
+ ##### Non-metric TSP (Explicit TSP, `exp`)
111
+
112
+ 1. Randomly sample an **N×N symmetric distance matrix** with **maximum value ≤ 200**.
113
+ 2. Solve with a classic TSP solver (e.g., [Concorde](https://www.math.uwaterloo.ca/tsp/concorde.html)).
114
+ 3. Reorder rows/columns so that $0 \rightarrow 1 \rightarrow ... \rightarrow N-1$ is the optimal tour.
115
+ 4. Sample **T** instances for each **N**, stack them into a NumPy array of shape **`[T, N, N]`** with dtype **`int8`** (see the sketch after this list), and save grouped-by-N arrays as:
116
+ - `data/TSP/exp/non-uni/raw/{N:03d}.npy`
117
+
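+ A rough sketch of the `exp` layout described above (the `euc` layout is analogous, but stores `[T, N, 2]` integer coordinates). The `solve_tsp` helper below is only a runnable placeholder; the benchmark expects optimal tours from an exact solver such as Concorde:
+
+ ```python
+ import os
+ import numpy as np
+
+ N, T = 20, 4  # toy sizes; the real datasets use many more instances
+
+ def solve_tsp(dist):
+     """Placeholder tour (greedy nearest neighbour) so the sketch runs.
+     Replace with an exact solver (e.g. Concorde) to obtain optimal tours."""
+     n = dist.shape[0]
+     order, visited = [0], {0}
+     while len(order) < n:
+         last = order[-1]
+         nxt = min((j for j in range(n) if j not in visited), key=lambda j: dist[last, j])
+         order.append(nxt)
+         visited.add(nxt)
+     return np.array(order)
+
+ instances = []
+ for _ in range(T):
+     a = np.random.randint(1, 128, size=(N, N))    # int8-safe values (assumed range)
+     dist = np.triu(a, 1) + np.triu(a, 1).T        # symmetric matrix, zero diagonal
+     order = solve_tsp(dist)                       # tour as a node permutation
+     instances.append(dist[np.ix_(order, order)])  # reorder so 0 -> 1 -> ... -> N-1 is the tour
+
+ arr = np.stack(instances).astype(np.int8)         # shape [T, N, N], dtype int8
+ os.makedirs("data/TSP/exp/non-uni/raw", exist_ok=True)
+ np.save(f"data/TSP/exp/non-uni/raw/{N:03d}.npy", arr)
+ ```
+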
118
+ #### Training from scratch
119
+
120
+ Recommended (paper-matching) training setup: 8 nodes × 8 GPUs = 64 GPUs total.
121
+
122
+ ```bash
123
+ source .venv/bin/activate
124
+ cd example
125
+ torchrun \
126
+ --master_addr <MASTER_ADDR> \
127
+ --master_port <MASTER_PORT> \
128
+ --nnodes 8 \
129
+ --node_rank <NODE_RANK> \
130
+ --nproc_per_node 8 \
131
+ -m TSP.run \
132
+ --subset exp \
133
+ --output_dir ./outputs/TSP_exp \
134
+ --wandb_name TSP_exp
135
+ ```
136
+
137
+ ---
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "floydnet"
7
- version = "0.1.0"
7
+ version = "0.1.2"
8
8
  description = "Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.9"
@@ -50,11 +50,11 @@ include = [
50
50
  "LICENSE",
51
51
  "CITATION.cff",
52
52
  "CHANGELOG.md",
53
- "paper/**",
53
+ "src/**",
54
54
  ]
55
55
 
56
56
  [tool.hatch.build.targets.wheel]
57
- packages = ["src/floyd_net"]
57
+ packages = ["src/floydnet"]
58
58
 
59
59
  [tool.ruff]
60
60
  line-length = 100
@@ -0,0 +1,8 @@
1
+ from .functional import pivotal_attention, pivotal_attention3
2
+ from .transformer import PivotalAttentionBlock
3
+
4
+ __all__ = [
5
+ "pivotal_attention",
6
+ "pivotal_attention3",
7
+ "PivotalAttentionBlock",
8
+ ]
@@ -0,0 +1,150 @@
1
+ # Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Optional
19
+ import math
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+ def pivotal_attention(
25
+ q_ik: torch.Tensor,
26
+ k_ij: torch.Tensor,
27
+ k_jk: torch.Tensor,
28
+ v_ij: torch.Tensor,
29
+ v_jk: torch.Tensor,
30
+ attn_mask: Optional[torch.Tensor] = None,
31
+ dropout: float = 0.0,
32
+ scale: Optional[float] = None,
33
+ inf: float = 1e9,
34
+ ) -> torch.Tensor:
35
+ """Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
36
+
37
+ Shapes:
38
+ q_ik: (B, H, L_i, L_k, D)
39
+ k_ij: (B, H, L_i, L_j, D)
40
+ k_jk: (B, H, L_j, L_k, D)
41
+ v_ij: (B, H, L_i, L_j, D)
42
+ v_jk: (B, H, L_j, L_k, D)
43
+ attn_mask (optional): broadcastable to (B, H, L_i, L_k, L_j)
44
+
45
+ Args:
46
+ attn_mask: Additive mask (float) or boolean mask. If boolean, masked positions are set to -inf.
47
+ dropout: Dropout probability applied to attention weights (only effective if > 0).
48
+ scale: Optional custom scaling factor. If None, defaults to 1/sqrt(2*D).
49
+ inf: Value to use for -infinity in masks.
50
+
51
+ Returns:
52
+ Tensor of shape (B, H, L_i, L_k, D)
53
+ """
54
+ assert all([t.dim() == 5 for t in [q_ik, k_ij, k_jk, v_ij, v_jk]]), "All inputs must be 5D tensors"
55
+ B, H, L_i, L_k, D = q_ik.shape
56
+ L_j = k_ij.shape[3]
57
+ assert k_ij.shape == v_ij.shape == (B, H, L_i, L_j, D), "k_ij and v_ij must have shape (B, H, L_i, L_j, D)"
58
+ assert k_jk.shape == v_jk.shape == (B, H, L_j, L_k, D), "k_jk and v_jk must have shape (B, H, L_j, L_k, D)"
59
+
60
+ if scale is None:
61
+ scale = 1.0 / math.sqrt(2.0 * D)
62
+ q_ik = q_ik * scale
63
+
64
+ # Compute attention scores over the pivot dimension j: (B, H, L_i, L_k, L_j)
65
+ attn_scores = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij) \
66
+ + torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)
67
+
68
+ if attn_mask is not None:
69
+ if attn_mask.dtype == torch.bool:
70
+ attn_scores = attn_scores.masked_fill(attn_mask, -inf)
71
+ else:
72
+ attn_scores = attn_scores + attn_mask
73
+
74
+ attn_weights = torch.softmax(attn_scores, dim=-1)
75
+
76
+ if dropout > 0.0:
77
+ attn_weights = F.dropout(attn_weights, p=dropout)
78
+
79
+ y = torch.einsum("bhikj,bhijd->bhikd", attn_weights, v_ij) \
80
+ + torch.einsum("bhikj,bhjkd->bhikd", attn_weights, v_jk)
81
+
82
+ return y
83
+
84
+ def pivotal_attention3(
85
+ q_ijk: torch.Tensor,
86
+ k_pjk: torch.Tensor,
87
+ k_ipk: torch.Tensor,
88
+ k_ijp: torch.Tensor,
89
+ v_pjk: torch.Tensor,
90
+ v_ipk: torch.Tensor,
91
+ v_ijp: torch.Tensor,
92
+ attn_mask: Optional[torch.Tensor] = None,
93
+ dropout: float = 0.0,
94
+ scale: Optional[float] = None,
95
+ inf: float = 1e9,
96
+ ) -> torch.Tensor:
97
+ """3-Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
98
+
99
+ Shapes:
100
+ q_ijk: (B, H, L_i, L_j, L_k, D)
101
+ k_pjk: (B, H, L_p, L_j, L_k, D)
102
+ k_ipk: (B, H, L_i, L_p, L_k, D)
103
+ k_ijp: (B, H, L_i, L_j, L_p, D)
104
+ v_pjk: (B, H, L_p, L_j, L_k, D)
105
+ v_ipk: (B, H, L_i, L_p, L_k, D)
106
+ v_ijp: (B, H, L_i, L_j, L_p, D)
107
+ attn_mask (optional): broadcastable to (B, H, L_i, L_j, L_k, L_p)
108
+
109
+ Args:
110
+ attn_mask: Additive mask (float) or boolean mask. If boolean, masked positions are set to -inf.
111
+ dropout: Dropout probability applied to attention weights (only effective if > 0).
112
+ scale: Optional custom scaling factor. If None, defaults to 1/sqrt(3*D).
113
+ inf: Value to use for -infinity in masks.
114
+
115
+ Returns:
116
+         Tensor of shape (B, H, L_i, L_j, L_k, D)
117
+ """
118
+ assert all([t.dim() == 6 for t in [q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp]]), "All inputs must be 6D tensors"
119
+ B, H, L_i, L_j, L_k, D = q_ijk.shape
120
+ L_p = k_pjk.shape[2]
121
+ assert k_pjk.shape == v_pjk.shape == (B, H, L_p, L_j, L_k, D), "k_pjk and v_pjk must have shape (B, H, L_p, L_j, L_k, D)"
122
+ assert k_ipk.shape == v_ipk.shape == (B, H, L_i, L_p, L_k, D), "k_ipk and v_ipk must have shape (B, H, L_i, L_p, L_k, D)"
123
+ assert k_ijp.shape == v_ijp.shape == (B, H, L_i, L_j, L_p, D), "k_ijp and v_ijp must have shape (B, H, L_i, L_j, L_p, D)"
124
+
125
+ if scale is None:
126
+ scale = 1.0 / math.sqrt(3.0 * D)
127
+ q_ijk = q_ijk * scale
128
+
129
+ # Compute attention scores over the pivot dimension j: (B, H, L_i, L_j, L_k, L_p)
130
+ attn_scores = torch.einsum("bhijkd,bhpjkd->bhijkp", q_ijk, k_pjk) \
131
+ + torch.einsum("bhijkd,bhipkd->bhijkp", q_ijk, k_ipk) \
132
+ + torch.einsum("bhijkd,bhijpd->bhijkp", q_ijk, k_ijp)
133
+
134
+ if attn_mask is not None:
135
+ if attn_mask.dtype == torch.bool:
136
+ attn_scores = attn_scores.masked_fill(attn_mask, -inf)
137
+ else:
138
+ attn_scores = attn_scores + attn_mask
139
+
140
+ attn_weights = torch.softmax(attn_scores, dim=-1)
141
+
142
+ if dropout > 0.0:
143
+ attn_weights = F.dropout(attn_weights, p=dropout)
144
+
145
+ y = torch.einsum("bhijkp,bhpjkd->bhijkd", attn_weights, v_pjk) \
146
+ + torch.einsum("bhijkp,bhipkd->bhijkd", attn_weights, v_ipk) \
147
+ + torch.einsum("bhijkp,bhijpd->bhijkd", attn_weights, v_ijp)
148
+
149
+ return y
150
+
@@ -0,0 +1,219 @@
1
+ # Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import copy
18
+ from typing import Optional, Tuple, Union, Callable
19
+
20
+ import torch
21
+ from torch import nn
22
+ from .functional import pivotal_attention
23
+
24
+
25
+ class Affine(nn.Module):
26
+ def __init__(self, c):
27
+ super().__init__()
28
+ self.weight = nn.Parameter(torch.ones((c, )))
29
+ self.bias = nn.Parameter(torch.zeros((c, )))
30
+
31
+ def forward(self, x: torch.Tensor):
32
+ return x * self.weight + self.bias
33
+
34
+
35
+ def create_norm(norm_fn: Union[str, Callable], embed_dim: int, eps: float = 1e-5, **kwargs) -> nn.Module:
36
+ """Create a normalization module from a name or nn.Module.
37
+
38
+ Args:
39
+ norm_fn: Name or an nn.Module instance/class.
40
+ embed_dim: Embedding dimension (features) used to construct the norm.
41
+ eps: Numerical epsilon passed to the normalization layer if applicable.
42
+ **kwargs: Extra keyword arguments forwarded to the normalization layer.
43
+
44
+ Returns:
45
+ An nn.Module normalization instance.
46
+ """
47
+ if isinstance(norm_fn, str):
48
+ if norm_fn.lower() in ["layernorm", "ln"]:
49
+ return nn.LayerNorm(embed_dim, eps=eps, **kwargs)
50
+ elif norm_fn.lower() in ["batchnorm", "bn"]:
51
+ return nn.BatchNorm1d(embed_dim, eps=eps, **kwargs)
52
+ elif norm_fn.lower() in ["rmsnorm", "rms"]:
53
+ return nn.RMSNorm(embed_dim, eps=eps, **kwargs)
54
+ elif norm_fn.lower() in ["affine"]:
55
+ return Affine(embed_dim)
56
+ elif norm_fn.lower() in ["none", "identity"]:
57
+ return nn.Identity()
58
+ else:
59
+ raise ValueError(f"Unsupported norm_fn string: {norm_fn}")
60
+ elif callable(norm_fn):
61
+ if isinstance(norm_fn, nn.Module):
62
+ # deepcopy to avoid shared parameters
63
+ return copy.deepcopy(norm_fn)
64
+ elif isinstance(norm_fn, type) and issubclass(norm_fn, nn.Module):
65
+ return norm_fn(embed_dim, eps=eps, **kwargs)
66
+ else:
67
+ raise TypeError("norm_fn callable must be an nn.Module or nn.Module class")
68
+ else:
69
+ raise TypeError("norm_fn must be a string or callable")
70
+
71
+
72
+ def create_activation(activation_fn: Union[str, Callable]) -> nn.Module:
73
+ """Create an activation module from a name or nn.Module.
74
+
75
+ Args:
76
+ activation_fn: Name or an nn.Module instance/class.
77
+
78
+ Returns:
79
+ An nn.Module activation instance.
80
+ """
81
+ if isinstance(activation_fn, str):
82
+ if activation_fn.lower() == "relu":
83
+ return nn.ReLU()
84
+ elif activation_fn.lower() == "gelu":
85
+ return nn.GELU()
86
+ elif activation_fn.lower() == "silu":
87
+ return nn.SiLU()
88
+ else:
89
+ raise ValueError(f"Unsupported activation_fn string: {activation_fn}")
90
+ elif callable(activation_fn):
91
+ if isinstance(activation_fn, nn.Module):
92
+ return activation_fn
93
+ elif isinstance(activation_fn, type) and issubclass(activation_fn, nn.Module):
94
+ return activation_fn()
95
+ else:
96
+ raise TypeError("activation_fn callable must be an nn.Module or nn.Module class")
97
+ else:
98
+ raise TypeError("activation_fn must be a string or callable")
99
+
100
+
101
+ class PivotalAttentionBlock(nn.Module):
102
+ """Transformer-style block that applies pivotal attention followed by an FFN.
103
+
104
+ Args:
105
+ embed_dim: Input/hidden embedding dimension (D).
106
+ num_heads: Number of attention heads (D must be divisible by num_heads).
107
+ dropout: Dropout probability for attention output and FFN output.
108
+ bias: Whether to include bias terms in linear layers.
109
+ ffn_expansion_ratio: Expansion ratio for the FFN hidden size.
110
+ norm_position: "pre" or "post" layer normalization placement.
111
+ activation_fn: Activation name/module used in the FFN.
112
+ norm_fn: Normalization name/module used in the block.
113
+ """
114
+ def __init__(
115
+ self,
116
+ embed_dim: int,
117
+ num_heads: int,
118
+ dropout: float = 0.0,
119
+ bias: bool = False,
120
+ ffn_expansion_ratio: int = 4,
121
+ norm_position: str = "pre",
122
+ activation_fn: Union[str, Callable] = "relu",
123
+ norm_fn: Union[str, Callable] = "layernorm",
124
+ enable_symmetric_mix: bool = True,
125
+ enable_ffn: bool = True,
126
+ ) -> None:
127
+ super().__init__()
128
+ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
129
+ self.embed_dim = embed_dim
130
+ self.num_heads = num_heads
131
+ self.head_dim = embed_dim // num_heads
132
+ self.dropout = dropout
133
+ self.norm_position = norm_position.lower()
134
+ self.enable_ffn = enable_ffn
135
+ assert self.norm_position in ["pre", "post"], "norm_position must be 'pre' or 'post'"
136
+
137
+ self.enable_symmetric_mix = enable_symmetric_mix
138
+ if enable_symmetric_mix:
139
+ self.c_mix = nn.Linear(embed_dim, embed_dim, bias=bias)
140
+
141
+ self.c_qkv = nn.Linear(embed_dim, embed_dim * 5, bias=bias)
142
+ self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
143
+ self.dropout_fn = nn.Dropout(dropout)
144
+ self.norm1 = create_norm(norm_fn, embed_dim)
145
+ if self.enable_ffn:
146
+ self.activation_fn = create_activation(activation_fn)
147
+ self.norm2 = create_norm(norm_fn, embed_dim)
148
+ self.ffn = nn.Sequential(
149
+ nn.Linear(embed_dim, ffn_expansion_ratio * embed_dim, bias=bias),
150
+ self.activation_fn,
151
+ nn.Linear(ffn_expansion_ratio * embed_dim, embed_dim, bias=bias),
152
+ nn.Dropout(dropout),
153
+ )
154
+ self.ffn_scale = nn.Parameter(torch.tensor(1.0, requires_grad=True))
155
+
156
+ self._reset_parameters()
157
+
158
+ def _reset_parameters(self) -> None:
159
+ """Initialize parameters using Xavier for projections and zeros for output heads."""
160
+ if self.enable_symmetric_mix:
161
+ nn.init.zeros_(self.c_mix.weight)
162
+ nn.init.xavier_uniform_(self.c_qkv.weight)
163
+ nn.init.zeros_(self.c_proj.weight)
164
+ if self.enable_ffn:
165
+ nn.init.xavier_uniform_(self.ffn[0].weight)
166
+ nn.init.zeros_(self.ffn[2].weight)
167
+ if self.c_qkv.bias is not None:
168
+ if self.enable_symmetric_mix:
169
+ nn.init.zeros_(self.c_mix.bias)
170
+ nn.init.zeros_(self.c_qkv.bias)
171
+ nn.init.zeros_(self.c_proj.bias)
172
+ if self.enable_ffn:
173
+ nn.init.zeros_(self.ffn[0].bias)
174
+ nn.init.zeros_(self.ffn[2].bias)
175
+
176
+ def attn(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]) -> torch.Tensor:
177
+ """Apply pivotal attention over a (L x L) grid.
178
+
179
+ Args:
180
+ x: Input tensor of shape (B, L, L, D).
181
+ attn_mask: Optional mask broadcastable to (B, H, L, L, L).
182
+
183
+ Returns:
184
+ Tensor of shape (B, L, L, D) after attention projection and dropout.
185
+ """
186
+ B, L, _, D = x.shape
187
+ # [B, L, L, 5*D] -> 5 x [B, H, L, L, d]
188
+ qkv = torch.chunk(self.c_qkv(x), 5, dim=-1)
189
+ q_ik, k_ij, k_jk, v_ij, v_jk = map(
190
+ lambda t: t.view(B, L, L, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4),
191
+ qkv,
192
+ )
193
+
194
+ # [B, H, L, L, d]
195
+ y = pivotal_attention(
196
+ q_ik, k_ij, k_jk, v_ij, v_jk,
197
+ attn_mask=attn_mask,
198
+ dropout=self.dropout if self.training else 0.0,
199
+ )
200
+ y = y.permute(0, 2, 3, 1, 4).contiguous().view(B, L, L, D)
201
+ y = self.c_proj(y)
202
+ y = self.dropout_fn(y)
203
+ return y
204
+
205
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
206
+ if self.enable_symmetric_mix:
207
+ xT = self.c_mix(x.transpose(1, 2))
208
+ else:
209
+ xT = 0
210
+ if self.norm_position == "pre":
211
+ x = x + self.attn(self.norm1(x + xT), attn_mask)
212
+ if self.enable_ffn:
213
+ x = x + self.ffn(self.norm2(x)) * self.ffn_scale
214
+ else:
215
+ x = self.norm1(x + self.attn(x + xT, attn_mask))
216
+ if self.enable_ffn:
217
+ x = self.norm2(x + self.ffn(x)) * self.ffn_scale
218
+
219
+ return x
@@ -1,7 +0,0 @@
1
- # Changelog
2
-
3
- All notable changes to this project will be documented in this file.
4
-
5
- ## [0.1.0] - 2025-10-21
6
- - Initial public skeleton with module + functional attention
7
- - CI, tests, examples, paper scaffolding
floydnet-0.1.0/README.md DELETED
@@ -1,46 +0,0 @@
1
- # floyd-net
2
-
3
- Floyd Multi-Head Attention (F-MHA) is a drop-in variant of PyTorch's attention stack. It provides:
4
-
5
- - Module API: `FloydMultiheadAttention` mirroring `torch.nn.MultiheadAttention`
6
- - Functional API: `floyd_scaled_dot_product_attention` mirroring `torch.nn.functional.scaled_dot_product_attention`
7
-
8
- Install and manage with `uv` for a modern Python workflow.
9
-
10
- ## Quick start
11
-
12
- ```bash
13
- # Install with uv (recommended)
14
- uv venv --python 3.10
15
- source .venv/bin/activate
16
- uv pip install -e .[dev]
17
- ```
18
-
19
- ```python
20
- import torch
21
- from floyd_net import FloydMultiheadAttention
22
-
23
- m = FloydMultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
24
- x = torch.randn(2, 16, 64)
25
- out, attn = m(x, x, x)
26
- print(out.shape)
27
- ```
28
-
29
- ### Functional API
30
- ```python
31
- import torch
32
- import torch.nn.functional as F
33
- from floyd_net import floyd_scaled_dot_product_attention
34
-
35
- q = torch.randn(2, 8, 16, 64) # (B, H, L, D)
36
- k = torch.randn(2, 8, 16, 64)
37
- v = torch.randn(2, 8, 16, 64)
38
- out = floyd_scaled_dot_product_attention(q, k, v)
39
- print(out.shape)
40
- ```
41
-
42
- ## Paper reproductions
43
- See `paper/` for dataset preparation, configs, and experiment templates to reproduce the results in the paper.
44
-
45
- ## License
46
- MIT
@@ -1,11 +0,0 @@
1
- # Paper reproductions
2
-
3
- This folder contains materials to reproduce results from the paper.
4
-
5
- Structure:
6
- - `datasets/`: dataset preparation scripts or links
7
- - `configs/`: YAML/TOML experiment configs
8
- - `experiments/`: runnable training/eval scripts referencing configs
9
- - `notebooks/`: exploratory notebooks (optional)
10
-
11
- Use `uv` to create an environment, then run scripts inside `experiments/`.