floydnet-1.0.0-py3-none-any.whl → floydnet-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
floydnet/functional.py CHANGED
@@ -31,6 +31,7 @@ def pivotal_attention(
     dropout: float = 0.0,
     scale: Optional[float] = None,
     inf: float = 1e9,
+    softmax_cap: float = -1,
 ) -> torch.Tensor:
     """Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
 
@@ -47,6 +48,9 @@ def pivotal_attention(
         dropout: Dropout probability applied to attention weights (only effective if > 0).
         scale: Optional custom scaling factor. If None, defaults to 1/sqrt(2*D).
         inf: Value to use for -infinity in masks.
+        softmax_cap: If > 0, applies a tanh-based logit cap before softmax.
+            Note: when using a non-boolean (additive) attn_mask, ensure its magnitude/semantics remain compatible
+            with capping (e.g., very large negative values used to approximate -inf can interact with logit shaping).
 
     Returns:
         Tensor of shape (B, H, L_i, L_k, D)
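
To illustrate the additive-mask caveat in the new docstring, here is a small standalone sketch (not code from the package): once logits are squashed into (-softmax_cap, softmax_cap), a large negative additive term such as -1e9 only keeps its "-inf" meaning if it is applied after the cap; folded into the logits before the cap, tanh flattens it to roughly -softmax_cap and the masked position is no longer reliably suppressed.

```python
import torch

cap = 30.0
logits = torch.tensor([-200.0, -150.0, 50.0])
mask = torch.tensor([0.0, 0.0, -1e9])        # additive mask on the last position

# Mask applied after capping: -1e9 still dwarfs any capped logit,
# so the last position gets essentially zero probability.
after = cap * torch.tanh(logits / cap) + mask
print(after.softmax(dim=-1))                 # ~[0.50, 0.50, 0.00]

# Mask folded in before capping: tanh squashes -1e9 to about -cap,
# putting the "masked" position on the same footing as the saturated logits.
before = cap * torch.tanh((logits + mask) / cap)
print(before.softmax(dim=-1))                # ~[0.33, 0.33, 0.33]
```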
@@ -65,6 +69,9 @@ def pivotal_attention(
     attn_scores = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij) \
         + torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)
 
+    if softmax_cap > 0:
+        attn_scores = softmax_cap * torch.tanh(attn_scores / softmax_cap)
+
     if attn_mask is not None:
         if attn_mask.dtype == torch.bool:
             attn_scores = attn_scores.masked_fill(attn_mask, -inf)
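
For readers skimming the hunk above, a minimal standalone sketch of the 2-pivotal scoring path with the new cap (toy shapes; the placement of the 1/sqrt(2*D) scale here is an assumption, not code copied from the package):

```python
import torch

B, H, L, D = 1, 2, 5, 8
softmax_cap = 15.0
scale = 1.0 / (2 * D) ** 0.5                 # default per the docstring; exactly where the
                                             # library applies it is not shown in this hunk

q_ik = torch.randn(B, H, L, L, D)            # query for each (i, k) pair
k_ij = torch.randn(B, H, L, L, D)            # keys indexed by (i, j)
k_jk = torch.randn(B, H, L, L, D)            # keys indexed by (j, k)

# Scores over the pivot j, combining the (i, j) and (j, k) keys as in the hunk above.
attn_scores = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij) \
    + torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk)
attn_scores = attn_scores * scale

# New in 1.1.0: bound every logit to (-softmax_cap, softmax_cap) before softmax.
if softmax_cap > 0:
    attn_scores = softmax_cap * torch.tanh(attn_scores / softmax_cap)

attn = attn_scores.softmax(dim=-1)           # distribution over pivots j
print(attn.shape)                            # torch.Size([1, 2, 5, 5, 5]) == (B, H, L_i, L_k, L_j)
```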
@@ -93,6 +100,7 @@ def pivotal_attention3(
     dropout: float = 0.0,
     scale: Optional[float] = None,
     inf: float = 1e9,
+    softmax_cap: float = -1,
 ) -> torch.Tensor:
     """3-Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING".
 
@@ -111,9 +119,12 @@ def pivotal_attention3(
         dropout: Dropout probability applied to attention weights (only effective if > 0).
         scale: Optional custom scaling factor. If None, defaults to 1/sqrt(3*D).
         inf: Value to use for -infinity in masks.
+        softmax_cap: If > 0, applies a tanh-based logit cap before softmax.
+            Note: when using a non-boolean (additive) attn_mask, ensure its magnitude/semantics remain compatible
+            with capping (e.g., very large negative values used to approximate -inf can interact with logit shaping).
 
     Returns:
-        Tensor of shape (B, H, L_i, l_j, L_k, D)
+        Tensor of shape (B, H, L_i, L_j, L_k, D)
     """
     assert all([t.dim() == 6 for t in [q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp]]), "All inputs must be 6D tensors"
     B, H, L_i, L_j, L_k, D = q_ijk.shape
@@ -130,6 +141,9 @@ def pivotal_attention3(
     attn_scores = torch.einsum("bhijkd,bhpjkd->bhijkp", q_ijk, k_pjk) \
         + torch.einsum("bhijkd,bhipkd->bhijkp", q_ijk, k_ipk) \
         + torch.einsum("bhijkd,bhijpd->bhijkp", q_ijk, k_ijp)
+
+    if softmax_cap > 0:
+        attn_scores = softmax_cap * torch.tanh(attn_scores / softmax_cap)
 
     if attn_mask is not None:
         if attn_mask.dtype == torch.bool:
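
A hedged usage sketch for the 3-pivotal variant. It assumes the keyword names match the tensor names in the assertion above and that `pivotal_attention3` is importable from `floydnet.functional` (the file in this diff); treat it as an illustration rather than documented API:

```python
import torch
from floydnet.functional import pivotal_attention3   # import path inferred from this diff

B, H, L, D = 1, 2, 4, 8
t = lambda: torch.randn(B, H, L, L, L, D)             # all seven inputs are 6D, per the assertion

out = pivotal_attention3(
    q_ijk=t(), k_pjk=t(), k_ipk=t(), k_ijp=t(),
    v_pjk=t(), v_ipk=t(), v_ijp=t(),
    softmax_cap=30.0,                                  # new in 1.1.0: tanh logit cap for stability
)
print(out.shape)                                       # expected (B, H, L_i, L_j, L_k, D) per the docstring
```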
floydnet-1.0.0.dist-info/METADATA → floydnet-1.1.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: floydnet
-Version: 1.0.0
+Version: 1.1.0
 Summary: Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs
 Project-URL: Homepage, https://github.com/ocx-lab/FloydNet
 Project-URL: Repository, https://github.com/ocx-lab/FloydNet
@@ -235,7 +235,7 @@ Description-Content-Type: text/markdown
 [![Python](https://img.shields.io/badge/Python-3.9%2B-blue)](https://www.python.org/)
 [![PyTorch](https://img.shields.io/badge/PyTorch-2.1%2B-orange)](https://pytorch.org/)
 
-Official implementation of [FloydNet](https://openreview.net/pdf?id=aUsx1G6RVQ).
+Official implementation of [FloydNet](https://arxiv.org/pdf/2601.19094).
 
 ![Figure Pivotal Attention Mechanism for 2-Floyd/3-Floyd.](misc/pivotalattn2&3.png)
 
@@ -253,7 +253,7 @@ The repository provides:
 1. **Reusable components**: a drop-in attention/Transformer-block interface intended for integration into existing projects.
 2. **Reproduction code**: end-to-end training/evaluation pipelines to reproduce the benchmarks reported in the paper.
 
-For algorithmic details, hyperparameter choices, and analysis, please refer to the paper (TODO: link).
+For algorithmic details, hyperparameter choices, and analysis, please refer to the [paper](https://arxiv.org/pdf/2601.19094).
 
 ---
 
@@ -360,9 +360,9 @@ uv pip install -e .
 
 ## Changelog (latest)
 
-- Full release with training and evaluation scripts for Graph Count, BREC, and TSP.
-- Added `pivotal_attention3` functional API for 3-Floyd attention.
-- Added additional configuration options in `PivotalAttentionBlock`.
+- Added `softmax_cap` parameter to `pivotal_attention` and `pivotal_attention3` for improved numerical stability.
+- Added LRGB example script.
+
 
 The full changelog is in [CHANGELOG.md](CHANGELOG.md).
 
@@ -371,12 +371,15 @@ The full changelog is in [CHANGELOG.md](CHANGELOG.md).
 If you use this code in your research, please cite the paper:
 
 ```bibtex
-@inproceedings{TODO,
-  title = {TODO},
-  author = {TODO},
-  booktitle = {TODO},
-  year = {TODO},
-  url = {TODO}
+@misc{yu2026floydnetlearningparadigmglobal,
+  title={FloydNet: A Learning Paradigm for Global Relational Reasoning},
+  author={Jingcheng Yu and Mingliang Zeng and Qiwei Ye},
+  year={2026},
+  eprint={2601.19094},
+  archivePrefix={arXiv},
+  primaryClass={cs.LG},
+  url={https://arxiv.org/abs/2601.19094},
+
 }
 ```
 
floydnet-1.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+floydnet/__init__.py,sha256=0AnWvQ3EQu8NhEsNWIffH6ZEf3jKpJ26zTQWuV0Ersg,204
+floydnet/functional.py,sha256=Usy_9AmJxzkx-HT8qJZaB0ufzVZn8fSEfcj6mgIi1Lk,6600
+floydnet/transformer.py,sha256=jMTEKGkO64rlJqtB5-LFpsInkjyv8_Zo3Tvzd6ZZtPs,8678
+floydnet-1.1.0.dist-info/METADATA,sha256=sExDcksEtcw0pPDJO5w9y4Oj9S_4hGfUIa1upzqedtw,18727
+floydnet-1.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+floydnet-1.1.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+floydnet-1.1.0.dist-info/RECORD,,
floydnet-1.0.0.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
-floydnet/__init__.py,sha256=0AnWvQ3EQu8NhEsNWIffH6ZEf3jKpJ26zTQWuV0Ersg,204
-floydnet/functional.py,sha256=96TcAbb_YmsT0H2Hqt4EuKy1v5Xkdv0ZKNtCfRx6ab0,5716
-floydnet/transformer.py,sha256=jMTEKGkO64rlJqtB5-LFpsInkjyv8_Zo3Tvzd6ZZtPs,8678
-floydnet-1.0.0.dist-info/METADATA,sha256=h3UXHjbQPZ59gdQ5UXkpYOB0AHZ_YBc4zli-iMmixn4,18608
-floydnet-1.0.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-floydnet-1.0.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-floydnet-1.0.0.dist-info/RECORD,,