late-interaction-kernels 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- late_interaction_kernels/__init__.py +206 -0
- late_interaction_kernels/_autotune.py +129 -0
- late_interaction_kernels/_utils.py +50 -0
- late_interaction_kernels/autograd.py +244 -0
- late_interaction_kernels/backward.py +274 -0
- late_interaction_kernels/backward_csr.py +202 -0
- late_interaction_kernels/backward_unified.py +277 -0
- late_interaction_kernels/experimental/__init__.py +42 -0
- late_interaction_kernels/forward.py +273 -0
- late_interaction_kernels/fp8.py +415 -0
- late_interaction_kernels/fused_head.py +479 -0
- late_interaction_kernels/matryoshka.py +253 -0
- late_interaction_kernels/plaid.py +979 -0
- late_interaction_kernels/py.typed +0 -0
- late_interaction_kernels/pylate_compat.py +213 -0
- late_interaction_kernels/reference.py +373 -0
- late_interaction_kernels/retrieve.py +272 -0
- late_interaction_kernels/scatter.py +202 -0
- late_interaction_kernels/smooth.py +639 -0
- late_interaction_kernels/soft.py +328 -0
- late_interaction_kernels/topk.py +91 -0
- late_interaction_kernels/varlen.py +490 -0
- late_interaction_kernels/xtr.py +87 -0
- late_interaction_kernels-0.0.1.dist-info/METADATA +252 -0
- late_interaction_kernels-0.0.1.dist-info/RECORD +27 -0
- late_interaction_kernels-0.0.1.dist-info/WHEEL +4 -0
- late_interaction_kernels-0.0.1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""Fused Triton kernels for late-interaction (MaxSim) scoring.
|
|
2
|
+
|
|
3
|
+
Common entry points::
|
|
4
|
+
|
|
5
|
+
from late_interaction_kernels import patch_pylate, MaxSimScorer, retrieve
|
|
6
|
+
|
|
7
|
+
patch_pylate() # PyLate drop-in
|
|
8
|
+
scorer = MaxSimScorer(normalize=True) # nn.Module, autograd-aware
|
|
9
|
+
scores, idx = retrieve(Q, D, top_k=100)
|
|
10
|
+
|
|
11
|
+
See the README for the full API and benchmarks.
|
|
12
|
+
FP8 helpers live in ``late_interaction_kernels.fp8``.
|
|
13
|
+
Research kernels live in ``late_interaction_kernels.experimental``.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from importlib.metadata import PackageNotFoundError
|
|
17
|
+
from importlib.metadata import version as _pkg_version
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
__version__ = _pkg_version("late-interaction-kernels")
|
|
21
|
+
except PackageNotFoundError: # pragma: no cover — running from a source tree without install
|
|
22
|
+
__version__ = "0.0.0+unknown"
|
|
23
|
+
|
|
24
|
+
# The kernels need Triton (Linux + CUDA). On macOS / Windows we still want
|
|
25
|
+
# `import late_interaction_kernels` to succeed so users can develop against
|
|
26
|
+
# the pure-PyTorch reference and `MaxSimScorer` / `retrieve` fallbacks.
|
|
27
|
+
try:
|
|
28
|
+
import triton # noqa: F401
|
|
29
|
+
|
|
30
|
+
_HAS_TRITON = True
|
|
31
|
+
except ImportError: # pragma: no cover
|
|
32
|
+
_HAS_TRITON = False
|
|
33
|
+
|
|
34
|
+
if _HAS_TRITON:
|
|
35
|
+
from .autograd import (
|
|
36
|
+
get_backward_method,
|
|
37
|
+
maxsim,
|
|
38
|
+
maxsim_inference,
|
|
39
|
+
set_backward_method,
|
|
40
|
+
)
|
|
41
|
+
from .fp8 import maxsim_inference_fp8
|
|
42
|
+
from .fused_head import maxsim_from_hidden, maxsim_from_hidden_train
|
|
43
|
+
from .plaid import (
|
|
44
|
+
maxsim_residual,
|
|
45
|
+
maxsim_residual_varlen,
|
|
46
|
+
plaid_approx_score,
|
|
47
|
+
)
|
|
48
|
+
from .pylate_compat import patch_pylate, unpatch_pylate
|
|
49
|
+
from .scatter import maxsim_inference_scatter
|
|
50
|
+
from .varlen import maxsim_varlen
|
|
51
|
+
else: # pragma: no cover
|
|
52
|
+
|
|
53
|
+
def _needs_triton(*_args, **_kwargs): # type: ignore[no-redef]
|
|
54
|
+
raise RuntimeError(
|
|
55
|
+
"late-interaction-kernels GPU kernels require Triton, which isn't "
|
|
56
|
+
"installed on this platform. Install a CUDA-enabled Triton (Linux only) "
|
|
57
|
+
"or use `late_interaction_kernels.reference` for the pure-PyTorch path."
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
maxsim = maxsim_inference = _needs_triton
|
|
61
|
+
maxsim_from_hidden = maxsim_from_hidden_train = _needs_triton
|
|
62
|
+
maxsim_inference_fp8 = _needs_triton
|
|
63
|
+
maxsim_varlen = _needs_triton
|
|
64
|
+
plaid_approx_score = _needs_triton
|
|
65
|
+
maxsim_residual = maxsim_residual_varlen = _needs_triton
|
|
66
|
+
maxsim_inference_scatter = _needs_triton
|
|
67
|
+
set_backward_method = get_backward_method = _needs_triton
|
|
68
|
+
patch_pylate = unpatch_pylate = _needs_triton
|
|
69
|
+
|
|
70
|
+
# `MaxSimScorer` and `retrieve` are always importable: they fall back to the
|
|
71
|
+
# pure-PyTorch reference on platforms without Triton, so training and
|
|
72
|
+
# retrieval code can be unit-tested locally.
|
|
73
|
+
from . import reference # noqa: E402,F401
|
|
74
|
+
from .retrieve import MaxSimScorer, retrieve # noqa: E402
|
|
75
|
+
|
|
76
|
+
# Names that used to live at the package top level and now live in a
# submodule. `__getattr__` still serves them (with a `DeprecationWarning`);
# they are scheduled for removal in a future release.
_DEPRECATED_EXPERIMENTAL = dict.fromkeys(
    ("maxsim_matryoshka", "maxsim_xtr", "soft_maxsim", "smooth_maxsim"),
    "late_interaction_kernels.experimental",
)

_DEPRECATED_FP8_HELPERS = dict.fromkeys(
    (
        "quantize_fp8_per_tensor",
        "quantize_fp8_per_token",
        "dequantize_fp8_per_tensor",
        "dequantize_fp8_per_token",
    ),
    "late_interaction_kernels.fp8",
)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def __getattr__(name: str):
    """PEP 562 hook: serve deprecated / relocated names with a ``DeprecationWarning``.

    Every recognised name first emits the warning (attributed to the caller
    via ``stacklevel=2``), then resolves lazily from the submodule that now
    owns it. Without Triton the ``_needs_triton`` stub is returned instead of
    importing anything. Any other name raises ``AttributeError``.
    """
    import warnings

    # Names that moved wholesale into `experimental` / `fp8`.
    if name in _DEPRECATED_EXPERIMENTAL or name in _DEPRECATED_FP8_HELPERS:
        new_home = {**_DEPRECATED_EXPERIMENTAL, **_DEPRECATED_FP8_HELPERS}[name]
        warnings.warn(
            f"`late_interaction_kernels.{name}` moved to `{new_home}`. Use `from {new_home} import {name}`.",
            DeprecationWarning,
            stacklevel=2,
        )
        if not _HAS_TRITON:
            return _needs_triton
        if name in _DEPRECATED_EXPERIMENTAL:
            from . import experimental as home
        else:
            from . import fp8 as home
        return getattr(home, name)

    # Individually deprecated entry points, each with its own migration hint.
    notices = {
        "maxsim_forward": (
            "`late_interaction_kernels.maxsim_forward` is deprecated. Use "
            "`maxsim_inference` for reranking, `maxsim` for gradients, or "
            "import the primitive from `late_interaction_kernels.forward`."
        ),
        "maxsim_topk": (
            "`maxsim_topk` is deprecated; use `retrieve(Q, D, top_k=...)` "
            "(same semantics, transparent CPU fallback). Still importable from "
            "`late_interaction_kernels.topk`."
        ),
        "maxsim_residual_inference": (
            "`maxsim_residual_inference` is deprecated; `maxsim_residual` "
            "auto-skips the argmax save when `Q.requires_grad=False`."
        ),
        "maxsim_varlen_inference": (
            "`maxsim_varlen_inference` is deprecated; `maxsim_varlen` "
            "auto-skips the argmax save when neither input requires grad."
        ),
    }
    if name not in notices:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(notices[name], DeprecationWarning, stacklevel=2)
    if not _HAS_TRITON:
        return _needs_triton
    if name == "maxsim_forward":
        from .forward import maxsim_forward as resolved
    elif name == "maxsim_topk":
        from .topk import maxsim_topk as resolved
    elif name == "maxsim_residual_inference":
        from .plaid import maxsim_residual_inference as resolved
    else:
        from .varlen import maxsim_varlen_inference as resolved
    return resolved
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# Public API exported by `from late_interaction_kernels import *`. The
# deprecated / relocated names that `__getattr__` still serves are not listed.
# Without Triton, the kernel entries below are bound to the `_needs_triton`
# stub, which raises when called.
__all__ = [
    "__version__",
    # high-level
    "MaxSimScorer",
    "retrieve",
    "patch_pylate",
    "unpatch_pylate",
    # core MaxSim
    "maxsim",
    "maxsim_inference",
    "maxsim_varlen",
    # reranking on packed batches
    "maxsim_inference_scatter",
    # fused D-side head
    "maxsim_from_hidden",
    "maxsim_from_hidden_train",
    # PLAID / ColBERTv2
    "plaid_approx_score",
    "maxsim_residual",
    "maxsim_residual_varlen",
    # FP8 inference
    "maxsim_inference_fp8",
    # configuration
    "set_backward_method",
    "get_backward_method",
    # the pure-PyTorch `reference` submodule (imported unconditionally)
    "reference",
]
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Autotune configs per GPU family.
|
|
2
|
+
|
|
3
|
+
Triton autotune compiles and benchmarks every candidate on the first call for a given key,
|
|
4
|
+
caches the winner, and reuses it forever after. We keep lists short — each
|
|
5
|
+
extra config costs a compile plus several timed launches.
|
|
6
|
+
|
|
7
|
+
Family rules of thumb (verified on H100 / A100 and conservative on the rest):
|
|
8
|
+
- Small ``d`` (≤ 128): prefer `BLOCK_Q=32-64, BLOCK_D=64-128`.
|
|
9
|
+
- Large ``d`` (≥ 512): shrink blocks so the fp16 `Q`/`D` tiles plus the fp32
|
|
10
|
+
`S` tile fit in the SM's shared-memory budget.
|
|
11
|
+
- Hopper loves `num_stages ≥ 3` (warp specialization + async copy).
|
|
12
|
+
- Ampere / Ada are happiest with `num_stages=2`.
|
|
13
|
+
|
|
14
|
+
Per-family SRAM budgets (KiB of shared memory the kernel can actually use):
|
|
15
|
+
- Hopper (H100 / H200): 228
|
|
16
|
+
- Ampere (A100): 164
|
|
17
|
+
- Ampere consumer (3090, A10): 100
|
|
18
|
+
- Ada (L4, L40, RTX 4090): 100
|
|
19
|
+
- Unknown / older: 48 (safe floor)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import inspect
|
|
25
|
+
|
|
26
|
+
import triton
|
|
27
|
+
|
|
28
|
+
from ._utils import detect_gpu
|
|
29
|
+
|
|
30
|
+
# Producer/consumer warp specialization (FA-3 style) needs Triton 3.2+, where
# ``triton.Config`` accepts ``num_consumer_groups`` / ``num_buffers_warp_spec``
# so tile loads overlap with ``tl.dot``. Probe the constructor signature
# instead of parsing version strings, so older Triton keeps working.
try:
    _CFG_PARAMS = set(inspect.signature(triton.Config).parameters)
except (TypeError, ValueError):  # pragma: no cover
    _CFG_PARAMS = set()
_HAS_WARP_SPEC = _CFG_PARAMS.issuperset({"num_consumer_groups", "num_buffers_warp_spec"})
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _cfg(kwargs, *, num_warps, num_stages, warp_spec=False):
    """Return a ``triton.Config`` for ``kwargs`` / ``num_warps`` / ``num_stages``.

    With ``warp_spec=True`` *and* a Triton that supports it (``_HAS_WARP_SPEC``),
    the config also requests two consumer groups and ``num_stages`` warp-spec
    buffers. Otherwise the warp-spec request is dropped silently.
    """
    warp_spec_kwargs = (
        {"num_consumer_groups": 2, "num_buffers_warp_spec": num_stages}
        if warp_spec and _HAS_WARP_SPEC
        else {}
    )
    return triton.Config(kwargs, num_warps=num_warps, num_stages=num_stages, **warp_spec_kwargs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _small_d_hopper():
    """Hopper (H100 / H200) shortlist for small ``d``: 3-stage pipelining plus warp-spec variants.

    The warp-specialized 64x128 entry is only added when the running Triton
    supports warp specialization (``_HAS_WARP_SPEC``). Without that support,
    ``_cfg`` silently drops the warp-spec kwargs. That entry would then be an
    exact duplicate of the plain 64x128 / 8-warp / 3-stage config, and every
    duplicate costs the autotuner an extra compile + benchmark.

    The 128x128 / 3-stage entry is kept either way. Without warp
    specialization it is still distinct from the 2-stage 128x128 config.
    """
    configs = [
        _cfg({"BLOCK_Q": 32, "BLOCK_D": 64}, num_warps=4, num_stages=3),
        _cfg({"BLOCK_Q": 32, "BLOCK_D": 128}, num_warps=8, num_stages=3),
        _cfg({"BLOCK_Q": 64, "BLOCK_D": 64}, num_warps=4, num_stages=3),
        _cfg({"BLOCK_Q": 64, "BLOCK_D": 128}, num_warps=8, num_stages=3),
        _cfg({"BLOCK_Q": 128, "BLOCK_D": 64}, num_warps=8, num_stages=2),
        _cfg({"BLOCK_Q": 128, "BLOCK_D": 128}, num_warps=8, num_stages=2),
    ]
    # Warp-specialized shortlist: a producer warp group streams Q/D tiles into
    # shared memory while consumer group(s) run back-to-back ``tl.dot`` +
    # running-max.
    if _HAS_WARP_SPEC:
        configs.append(
            _cfg({"BLOCK_Q": 64, "BLOCK_D": 128}, num_warps=8, num_stages=3, warp_spec=True)
        )
    configs.append(
        _cfg({"BLOCK_Q": 128, "BLOCK_D": 128}, num_warps=8, num_stages=3, warp_spec=True)
    )
    return configs
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _small_d_ampere():
    """Small-``d`` shortlist for A100, A10, A40, 3090; also a safe default for Ada (L4, L40, 4090)."""
    # (BLOCK_Q, BLOCK_D, num_warps, num_stages)
    shortlist = (
        (32, 64, 4, 2),
        (32, 128, 8, 2),
        (64, 64, 4, 2),
        (64, 128, 8, 2),
        (128, 64, 8, 1),
    )
    return [
        _cfg({"BLOCK_Q": block_q, "BLOCK_D": block_d}, num_warps=warps, num_stages=stages)
        for block_q, block_d, warps, stages in shortlist
    ]
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _large_d_configs():
    """Small-tile configs aimed at ``d >= 512``; the smallest tiles in this module, meant to fit any GPU."""
    # (BLOCK_Q, BLOCK_D, num_warps) -- every entry uses two pipeline stages.
    tiles = (
        (16, 16, 2),
        (16, 32, 2),
        (32, 16, 2),
        (32, 32, 4),
        (32, 64, 4),
    )
    return [
        _cfg({"BLOCK_Q": block_q, "BLOCK_D": block_d}, num_warps=warps, num_stages=2)
        for block_q, block_d, warps in tiles
    ]
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Shared-memory budget per GPU family, in KiB, keyed by the family strings
# that `detect_gpu()` returns. `prune_forward` subtracts an 8 KiB scratch
# reserve from these before comparing against a config's estimated tile
# footprint.
_SRAM_KIB_BY_FAMILY = {
    "hopper": 228,  # H100 / H200
    "a100": 164,
    "ampere": 100,  # other Ampere parts (3090, A10, A40)
    "ada": 100,  # L4, L40, RTX 40xx
    "generic": 48,  # unknown / older GPUs: conservative floor
}
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def forward_configs():
    """Autotune shortlist for the forward kernel on the current GPU family.

    The large-``d`` configs are always included. Hopper adds its small-``d``
    shortlist; A100 / other Ampere / Ada add the Ampere one. Unknown GPUs get
    only the large-``d`` set.
    """
    family = detect_gpu()
    shortlist = _large_d_configs()
    if family == "hopper":
        extra = _small_d_hopper()
    elif family in ("a100", "ampere", "ada"):
        extra = _small_d_ampere()
    else:
        extra = []  # minimal safe shortlist for unknown GPUs
    return shortlist + extra
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def prune_forward(configs, named_args, **kwargs):
    """Drop configs that overflow shared memory or are oversized for the problem.

    The signature matches Triton's ``prune_configs_by={"early_config_prune": ...}``
    hook (presumably how it is wired up -- confirm at the ``@triton.autotune``
    site). Only ``named_args["Lq"]`` and ``named_args["d"]`` are read; they
    default to 32 and 128.

    A config is kept when:
      * its estimated tile footprint (fp16/bf16 ``Q`` and ``D`` tiles plus an
        fp32 ``S`` tile) fits the GPU family's budget minus an 8 KiB reserve, and
      * ``BLOCK_Q <= 2 * Lq``.

    If no config passes, the two configs with the smallest estimated footprint
    are returned, so the autotuner always has at least one candidate. The
    previous ``configs[:2]`` only worked because the smallest tiles happen to
    come first in the list.
    """
    Lq = named_args.get("Lq", 32)
    d = named_args.get("d", 128)
    # Reserve 8 KiB for Triton scratch; the rest is ours.
    sram_kib = _SRAM_KIB_BY_FAMILY.get(detect_gpu(), _SRAM_KIB_BY_FAMILY["generic"])
    sram_budget = (sram_kib - 8) * 1024

    def footprint(cfg):
        """Estimated shared-memory bytes for one config's tiles."""
        bq, bd = cfg.kwargs["BLOCK_Q"], cfg.kwargs["BLOCK_D"]
        # NOTE(review): num_stages is not accounted for; if Triton multi-buffers
        # the pipelined tiles this is a lower bound -- confirm.
        return (bq * d + bd * d) * 2 + bq * bd * 4

    keep = [
        cfg
        for cfg in configs
        if footprint(cfg) <= sram_budget and cfg.kwargs["BLOCK_Q"] <= 2 * Lq
    ]
    if keep:
        return keep
    # Nothing survived: fall back to the two least shared-memory-hungry configs
    # (stable sort keeps the original order among ties).
    return sorted(configs, key=footprint)[:2]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Small helpers shared across kernels."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def next_pow2(x: int) -> int:
    """Return the smallest power of two that is >= ``x``.

    Any ``x <= 1`` (including 0 and negatives) yields 1.
    """
    return 1 if x <= 1 else 1 << (x - 1).bit_length()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@functools.lru_cache(maxsize=1)
def detect_gpu() -> str:
    """Return a short GPU family string: 'hopper' | 'a100' | 'ada' | 'ampere' | 'generic'.

    Classification uses the CUDA compute capability of the current device,
    not substring matches on the marketing name. Name matching misfiled e.g.
    the Turing "Quadro RTX 4000" as Ada (it contains "rtx 40") and missed
    Hopper parts such as H800 / H20.

    Mapping:
      * sm_90         -> 'hopper'
      * sm_80         -> 'a100'
      * sm_86 / sm_87 -> 'ampere'
      * sm_89         -> 'ada'
      * anything else, no CUDA, or a ROCm build -> 'generic'

    The result is cached for the life of the process and reflects whichever
    device is current at the first call.
    """
    if not torch.cuda.is_available():
        return "generic"
    # ROCm builds expose AMD GPUs through `torch.cuda` and report gfx arch
    # numbers as the "capability"; those must not be read as NVIDIA SM versions.
    if getattr(torch.version, "hip", None):
        return "generic"
    major, minor = torch.cuda.get_device_capability()
    if major == 9:
        return "hopper"
    if major == 8:
        return {0: "a100", 6: "ampere", 7: "ampere", 9: "ada"}.get(minor, "generic")
    return "generic"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def ensure_contiguous_last(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged if its last dim has unit stride, else a contiguous copy.

    Only the innermost stride is checked, so tensors that are strided along
    leading dims pass through without a copy.
    """
    return x if x.stride(-1) == 1 else x.contiguous()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def pick_compute_dtype(Q: torch.Tensor, D: torch.Tensor) -> torch.dtype:
    """Pick the tile dtype for ``tl.dot`` (the accumulator stays fp32).

    Rule, applied in this order:
      * ``torch.bfloat16`` if *either* input is bf16, including fp32 + bf16
        and fp16 + bf16 mixes;
      * ``torch.float16`` otherwise, i.e. fp16 + fp16, fp16 + fp32 and fp32 + fp32.

    The previous docstring claimed that any fp32 input falls back to fp16;
    that is not true when the other input is bf16. Note that when fp32 inputs
    are cast to an fp16 tile, values beyond fp16's range (|x| > 65504) would
    overflow.
    """
    if torch.bfloat16 in (Q.dtype, D.dtype):
        return torch.bfloat16
    return torch.float16
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""User-facing autograd wrapper for the fused MaxSim kernel."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import warnings
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from .backward import maxsim_backward
|
|
11
|
+
from .backward_unified import maxsim_backward_unified
|
|
12
|
+
from .forward import _run_forward, maxsim_forward
|
|
13
|
+
|
|
14
|
+
# Every value accepted by `set_backward_method` and the per-call `backward=` kwarg.
_VALID_METHODS: tuple[str, ...] = ("auto", "atomic", "csr", "unified")

# Process-wide default backward strategy; rewritten only by `set_backward_method`.
_BACKWARD_METHOD: str = "auto"

# Latch for `_maybe_warn_unnormalized`: flips to True after the first warning
# so a tight training loop fed unnormalized inputs doesn't flood the logs.
_WARNED_UNNORMALIZED: bool = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def set_backward_method(method: str) -> None:
    """Choose the process-wide default ``grad_D`` strategy.

    Prefer the per-call ``backward=`` argument of :func:`maxsim` and
    :class:`~late_interaction_kernels.MaxSimScorer`. This global exists for
    back-compat and for pinning one method across a whole benchmark run.

    Accepted values:

    * ``"auto"`` — ``"unified"`` for almost every shape; ``"csr"`` when
      ``grad_D`` contention is very high (``Nq ≥ 256 ∧ Nd ≥ 256 ∧ Lq ≤ 64``).
    * ``"unified"`` — single-pass fused ``grad_Q + grad_D`` kernel.
    * ``"csr"`` — scatter-free bucketed reduction; bitwise-deterministic.
    * ``"atomic"`` — legacy two-pass with fp32 ``tl.atomic_add``.

    Raises:
        ValueError: if ``method`` is not one of the values above.
    """
    global _BACKWARD_METHOD
    if method in _VALID_METHODS:
        _BACKWARD_METHOD = method
        return
    raise ValueError(f"method must be one of {_VALID_METHODS}, got {method!r}")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_backward_method() -> str:
    """Return the process-wide default backward strategy.

    This is ``"auto"`` until :func:`set_backward_method` changes it. A per-call
    ``backward=`` argument to :func:`maxsim` takes precedence over this value.
    """
    current: str = _BACKWARD_METHOD
    return current
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _maybe_warn_unnormalized(Q: torch.Tensor) -> None:
    """Warn once when ``normalize=False`` is paired with non-normalized Q.

    ColBERT / ColPali / LateOn always score L2-normalized tokens. Calling
    ``maxsim`` on raw encoder outputs silently produces different score
    scales than PyLate. Silence with ``LIK_SUPPRESS_NORM_WARN=1``.

    The probe takes the median L2 norm of up to the first 64 tokens of ``Q``.
    It emits a ``UserWarning`` (at most once per process) when that median
    falls outside ``[0.9, 1.1]``.

    Reading the median back with ``.item()`` is a host-device sync. That sync
    fails while a CUDA graph is being captured and forces a graph break under
    ``torch.compile``, so the probe is skipped in both situations.
    """
    global _WARNED_UNNORMALIZED
    if _WARNED_UNNORMALIZED or os.environ.get("LIK_SUPPRESS_NORM_WARN", "0") == "1":
        return
    # Under torch.compile tracing the branch is resolved at trace time, so
    # skipping here keeps the compiled graph free of the `.item()` sync.
    compiler = getattr(torch, "compiler", None)
    is_compiling = getattr(compiler, "is_compiling", None)
    if is_compiling is not None and is_compiling():
        return
    # `.item()` during CUDA graph capture is an illegal sync and would raise.
    if Q.is_cuda and torch.cuda.is_current_stream_capturing():
        return
    # Cheap sanity check: a handful of token norms.
    with torch.no_grad():
        # Flatten leading dims, inspect up to the first 64 tokens.
        sample = Q.detach().reshape(-1, Q.shape[-1])[:64]
        if sample.numel() == 0:
            return
        med = sample.float().norm(dim=-1).median().item()
    if not (0.9 <= med <= 1.1):
        _WARNED_UNNORMALIZED = True
        warnings.warn(
            f"late-interaction-kernels: `maxsim(..., normalize=False)` but Q's median L2 norm "
            f"is {med:.3f} (ColBERT-style models expect ≈1.0). Pass `normalize=True` to fuse "
            "the L2-norm into the kernel, or pre-normalize with `F.normalize(Q, dim=-1)`. "
            "Silence with `LIK_SUPPRESS_NORM_WARN=1`.",
            UserWarning,
            stacklevel=3,
        )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class _MaxSimFn(torch.autograd.Function):
    """Fused MaxSim with saved argmax, 3-D inputs.

    Autograd bridge used by :func:`maxsim`. That caller passes contiguous
    ``Q`` / ``D``, int8 masks (or ``None``), the ``normalize`` flag and an
    already-validated backward-method string.

    * ``forward`` runs ``_run_forward(..., save_argmax=True)``. It saves
      ``Q``, ``D``, the argmax and both masks for the backward pass.
    * ``backward`` dispatches to ``maxsim_backward_unified`` or
      ``maxsim_backward``. When ``normalize`` is set, it maps the gradients
      w.r.t. the normalized tokens back through the L2-normalization Jacobian.
    """

    @staticmethod
    def forward(ctx, Q, D, q_mask, d_mask, normalize, backward_method):
        # The argmax lets backward route gradients to the winning D token
        # (exact index layout presumably defined in forward.py -- TODO confirm).
        scores, argmax = _run_forward(Q, D, q_mask, d_mask, save_argmax=True, normalize=normalize)
        # Masks may be None; save_for_backward accepts None entries.
        ctx.save_for_backward(Q, D, argmax, q_mask, d_mask)
        # Plain Python settings ride on ctx rather than in saved tensors.
        ctx.backward_method = backward_method
        ctx.normalize = normalize
        return scores

    @staticmethod
    def backward(ctx, grad_scores):
        Q, D, argmax, q_mask, d_mask = ctx.saved_tensors
        # Backward kernels always receive a dense fp32 upstream gradient,
        # whatever dtype autograd hands us.
        grad_scores = grad_scores.contiguous().to(torch.float32)

        # `auto` -> `unified` for typical training shapes; `csr` only when
        # `grad_D` contention is very high (large square batches, short Lq).
        method = ctx.backward_method
        if method == "auto":
            Nq, Lq, _ = Q.shape
            Nd = D.shape[0]
            high_contention = Nq >= 256 and Nd >= 256 and Lq <= 64
            method = "csr" if high_contention else "unified"

        def _bwd(Qt, Dt):
            # Returns (grad w.r.t. Qt, grad w.r.t. Dt) using the resolved method.
            if method == "unified":
                # NOTE(review): d_mask is not forwarded on this path and the
                # unified kernel is pinned to method="atomic"; presumably the
                # saved argmax already excludes masked D tokens -- confirm.
                return maxsim_backward_unified(grad_scores, Qt, Dt, argmax, q_mask=q_mask, method="atomic")
            # "csr" / "atomic": maxsim_backward with the method name forwarded.
            return maxsim_backward(
                grad_scores,
                Qt,
                Dt,
                argmax,
                q_mask,
                d_mask,
                method=method,
            )

        if ctx.normalize:
            # The forward computed scores against Q_hat = Q / ||Q|| and D_hat = D / ||D||.
            # We need grad w.r.t. the *unnormalized* Q and D. We get that by
            # (a) running the existing backward against the normalized tensors to get
            # grad_Q_hat, grad_D_hat, then (b) applying the L2-normalize Jacobian.
            # NOTE(review): the 1e-6 clamp should match the epsilon the forward
            # kernel uses for normalize=True -- confirm in forward.py.
            q_norm = torch.linalg.vector_norm(Q, dim=-1, keepdim=True).clamp_min(1e-6)
            d_norm = torch.linalg.vector_norm(D, dim=-1, keepdim=True).clamp_min(1e-6)
            Q_hat = Q / q_norm
            D_hat = D / d_norm
            grad_Qh, grad_Dh = _bwd(Q_hat, D_hat)
            # d Qhat / d Q = (I - Qhat Qhat^T) / ||Q||
            grad_Q = (grad_Qh - (grad_Qh * Q_hat).sum(-1, keepdim=True) * Q_hat) / q_norm
            grad_D = (grad_Dh - (grad_Dh * D_hat).sum(-1, keepdim=True) * D_hat) / d_norm
        else:
            grad_Q, grad_D = _bwd(Q, D)
        # Both grads are computed even if ctx.needs_input_grad marks one as unused.
        # masks, normalize, backward_method receive no gradient
        return grad_Q, grad_D, None, None, None, None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def maxsim(
|
|
137
|
+
Q: torch.Tensor,
|
|
138
|
+
D: torch.Tensor,
|
|
139
|
+
q_mask: torch.Tensor | None = None,
|
|
140
|
+
d_mask: torch.Tensor | None = None,
|
|
141
|
+
*,
|
|
142
|
+
normalize: bool = False,
|
|
143
|
+
backward: str | None = None,
|
|
144
|
+
) -> torch.Tensor:
|
|
145
|
+
"""Differentiable fused MaxSim. Drop-in for PyLate's ``colbert_scores``.
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
Q: ``[Nq, Lq, d]`` or ``[Lq, d]``.
|
|
149
|
+
D: ``[Nd, Ld, d]`` or ``[Ld, d]``.
|
|
150
|
+
q_mask, d_mask: bool tensors (``True`` = valid token).
|
|
151
|
+
normalize: L2-normalize Q and D per-token inside the kernel. Set to
|
|
152
|
+
``True`` for ColBERT / ColPali / LateOn-style scoring.
|
|
153
|
+
backward: per-call override of the ``grad_D`` strategy
|
|
154
|
+
(``"auto" | "unified" | "csr" | "atomic"``). ``None`` defers
|
|
155
|
+
to :func:`set_backward_method`.
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
scores: ``[Nq, Nd]`` fp32, squeezed to match 2-D inputs.
|
|
159
|
+
|
|
160
|
+
Inputs can be fp16 / bf16 / fp32 (fp32 accumulator). Gradients flow
|
|
161
|
+
into Q and D; masks are non-differentiable.
|
|
162
|
+
"""
|
|
163
|
+
q_was_2d = Q.dim() == 2
|
|
164
|
+
d_was_2d = D.dim() == 2
|
|
165
|
+
if q_was_2d:
|
|
166
|
+
Q = Q.unsqueeze(0)
|
|
167
|
+
if d_was_2d:
|
|
168
|
+
D = D.unsqueeze(0)
|
|
169
|
+
if q_mask is not None and q_mask.dim() == 1:
|
|
170
|
+
q_mask = q_mask.unsqueeze(0)
|
|
171
|
+
if d_mask is not None and d_mask.dim() == 1:
|
|
172
|
+
d_mask = d_mask.unsqueeze(0)
|
|
173
|
+
|
|
174
|
+
# Shape / device contract — fail fast with a clear message so user code
|
|
175
|
+
# doesn't silently corrupt memory or produce garbage scores.
|
|
176
|
+
if Q.shape[-1] != D.shape[-1]:
|
|
177
|
+
raise ValueError(
|
|
178
|
+
f"Q and D must share the embedding dim; got Q.shape[-1]={Q.shape[-1]} "
|
|
179
|
+
f"vs D.shape[-1]={D.shape[-1]}."
|
|
180
|
+
)
|
|
181
|
+
if Q.device != D.device:
|
|
182
|
+
raise ValueError(
|
|
183
|
+
f"Q and D must be on the same device; got Q.device={Q.device} vs D.device={D.device}."
|
|
184
|
+
)
|
|
185
|
+
if q_mask is not None and q_mask.device != Q.device:
|
|
186
|
+
raise ValueError(f"q_mask must be on the same device as Q; got {q_mask.device} vs {Q.device}.")
|
|
187
|
+
if d_mask is not None and d_mask.device != D.device:
|
|
188
|
+
raise ValueError(f"d_mask must be on the same device as D; got {d_mask.device} vs {D.device}.")
|
|
189
|
+
|
|
190
|
+
if backward is None:
|
|
191
|
+
method = _BACKWARD_METHOD
|
|
192
|
+
elif backward not in _VALID_METHODS:
|
|
193
|
+
raise ValueError(f"backward= must be one of {_VALID_METHODS} or None, got {backward!r}")
|
|
194
|
+
else:
|
|
195
|
+
method = backward
|
|
196
|
+
|
|
197
|
+
if not normalize:
|
|
198
|
+
_maybe_warn_unnormalized(Q)
|
|
199
|
+
|
|
200
|
+
Q = Q.contiguous()
|
|
201
|
+
D = D.contiguous()
|
|
202
|
+
q_mask_i8 = q_mask.contiguous().to(torch.int8) if q_mask is not None else None
|
|
203
|
+
d_mask_i8 = d_mask.contiguous().to(torch.int8) if d_mask is not None else None
|
|
204
|
+
|
|
205
|
+
scores = _MaxSimFn.apply(Q, D, q_mask_i8, d_mask_i8, normalize, method)
|
|
206
|
+
|
|
207
|
+
if q_was_2d and d_was_2d:
|
|
208
|
+
return scores.reshape(())
|
|
209
|
+
if q_was_2d:
|
|
210
|
+
return scores.squeeze(0)
|
|
211
|
+
if d_was_2d:
|
|
212
|
+
return scores.squeeze(-1)
|
|
213
|
+
return scores
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def maxsim_inference(
|
|
217
|
+
Q: torch.Tensor,
|
|
218
|
+
D: torch.Tensor,
|
|
219
|
+
q_mask: torch.Tensor | None = None,
|
|
220
|
+
d_mask: torch.Tensor | None = None,
|
|
221
|
+
*,
|
|
222
|
+
normalize: bool = False,
|
|
223
|
+
) -> torch.Tensor:
|
|
224
|
+
"""Inference-only MaxSim — like :func:`maxsim` but no saved argmax."""
|
|
225
|
+
if Q.shape[-1] != D.shape[-1]:
|
|
226
|
+
raise ValueError(
|
|
227
|
+
f"Q and D must share the embedding dim; got Q.shape[-1]={Q.shape[-1]} "
|
|
228
|
+
f"vs D.shape[-1]={D.shape[-1]}."
|
|
229
|
+
)
|
|
230
|
+
if Q.device != D.device:
|
|
231
|
+
raise ValueError(
|
|
232
|
+
f"Q and D must be on the same device; got Q.device={Q.device} vs D.device={D.device}."
|
|
233
|
+
)
|
|
234
|
+
if not normalize:
|
|
235
|
+
_maybe_warn_unnormalized(Q)
|
|
236
|
+
scores, _ = maxsim_forward(
|
|
237
|
+
Q,
|
|
238
|
+
D,
|
|
239
|
+
q_mask=q_mask,
|
|
240
|
+
d_mask=d_mask,
|
|
241
|
+
save_argmax=False,
|
|
242
|
+
normalize=normalize,
|
|
243
|
+
)
|
|
244
|
+
return scores
|