late-interaction-kernels 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,206 @@
1
+ """Fused Triton kernels for late-interaction (MaxSim) scoring.
2
+
3
+ Common entry points::
4
+
5
+ from late_interaction_kernels import patch_pylate, MaxSimScorer, retrieve
6
+
7
+ patch_pylate() # PyLate drop-in
8
+ scorer = MaxSimScorer(normalize=True) # nn.Module, autograd-aware
9
+ scores, idx = retrieve(Q, D, top_k=100)
10
+
11
+ See the README for the full API and benchmarks.
12
+ FP8 helpers live in ``late_interaction_kernels.fp8``.
13
+ Research kernels live in ``late_interaction_kernels.experimental``.
14
+ """
15
+
16
+ from importlib.metadata import PackageNotFoundError
17
+ from importlib.metadata import version as _pkg_version
18
+
19
+ try:
20
+ __version__ = _pkg_version("late-interaction-kernels")
21
+ except PackageNotFoundError: # pragma: no cover — running from a source tree without install
22
+ __version__ = "0.0.0+unknown"
23
+
24
+ # The kernels need Triton (Linux + CUDA). On macOS / Windows we still want
25
+ # `import late_interaction_kernels` to succeed so users can develop against
26
+ # the pure-PyTorch reference and `MaxSimScorer` / `retrieve` fallbacks.
27
+ try:
28
+ import triton # noqa: F401
29
+
30
+ _HAS_TRITON = True
31
+ except ImportError: # pragma: no cover
32
+ _HAS_TRITON = False
33
+
34
if not _HAS_TRITON:  # pragma: no cover

    def _needs_triton(*_args, **_kwargs):  # type: ignore[no-redef]
        """Stand-in for every Triton-only entry point; accepts anything and always raises."""
        message = (
            "late-interaction-kernels GPU kernels require Triton, which isn't "
            "installed on this platform. Install a CUDA-enabled Triton (Linux only) "
            "or use `late_interaction_kernels.reference` for the pure-PyTorch path."
        )
        raise RuntimeError(message)

    # Bind every public kernel name to the stub so attribute access and
    # `from late_interaction_kernels import ...` still succeed; the error is
    # only raised when one of them is actually called.
    maxsim = _needs_triton
    maxsim_inference = _needs_triton
    maxsim_varlen = _needs_triton
    maxsim_inference_scatter = _needs_triton
    maxsim_inference_fp8 = _needs_triton
    maxsim_from_hidden = _needs_triton
    maxsim_from_hidden_train = _needs_triton
    plaid_approx_score = _needs_triton
    maxsim_residual = _needs_triton
    maxsim_residual_varlen = _needs_triton
    set_backward_method = _needs_triton
    get_backward_method = _needs_triton
    patch_pylate = _needs_triton
    unpatch_pylate = _needs_triton
else:
    # Triton is importable: expose the real kernels.
    from .autograd import get_backward_method, maxsim, maxsim_inference, set_backward_method
    from .fp8 import maxsim_inference_fp8
    from .fused_head import maxsim_from_hidden, maxsim_from_hidden_train
    from .plaid import maxsim_residual, maxsim_residual_varlen, plaid_approx_score
    from .pylate_compat import patch_pylate, unpatch_pylate
    from .scatter import maxsim_inference_scatter
    from .varlen import maxsim_varlen
69
+
70
+ # `MaxSimScorer` and `retrieve` are always importable: they fall back to the
71
+ # pure-PyTorch reference on platforms without Triton, so training and
72
+ # retrieval code can be unit-tested locally.
73
+ from . import reference # noqa: E402,F401
74
+ from .retrieve import MaxSimScorer, retrieve # noqa: E402
75
+
76
+ # Symbols moved out of the top level. Still importable, with a
77
+ # `DeprecationWarning`. Scheduled for removal in a future release.
78
_DEPRECATED_EXPERIMENTAL = dict.fromkeys(
    ("maxsim_matryoshka", "maxsim_xtr", "soft_maxsim", "smooth_maxsim"),
    "late_interaction_kernels.experimental",
)

_DEPRECATED_FP8_HELPERS = dict.fromkeys(
    (
        "quantize_fp8_per_tensor",
        "quantize_fp8_per_token",
        "dequantize_fp8_per_tensor",
        "dequantize_fp8_per_token",
    ),
    "late_interaction_kernels.fp8",
)


def __getattr__(name: str):
    """PEP 562 hook: resolve deprecated / relocated top-level names lazily.

    For a known legacy name this emits a ``DeprecationWarning`` (attributed to
    the code performing the attribute access) and then returns the real object,
    imported on demand, or the ``_needs_triton`` stub when Triton is absent.

    Raises:
        AttributeError: ``name`` is not one of the handled legacy names.
    """
    import warnings

    # Names with a bespoke deprecation message.
    bespoke_messages = {
        "maxsim_forward": (
            "`late_interaction_kernels.maxsim_forward` is deprecated. Use "
            "`maxsim_inference` for reranking, `maxsim` for gradients, or "
            "import the primitive from `late_interaction_kernels.forward`."
        ),
        "maxsim_topk": (
            "`maxsim_topk` is deprecated; use `retrieve(Q, D, top_k=...)` "
            "(same semantics, transparent CPU fallback). Still importable from "
            "`late_interaction_kernels.topk`."
        ),
        "maxsim_residual_inference": (
            "`maxsim_residual_inference` is deprecated; `maxsim_residual` "
            "auto-skips the argmax save when `Q.requires_grad=False`."
        ),
        "maxsim_varlen_inference": (
            "`maxsim_varlen_inference` is deprecated; `maxsim_varlen` "
            "auto-skips the argmax save when neither input requires grad."
        ),
    }

    if name in bespoke_messages:
        message = bespoke_messages[name]
    elif name in _DEPRECATED_EXPERIMENTAL or name in _DEPRECATED_FP8_HELPERS:
        if name in _DEPRECATED_EXPERIMENTAL:
            new_home = _DEPRECATED_EXPERIMENTAL[name]
        else:
            new_home = _DEPRECATED_FP8_HELPERS[name]
        message = f"`late_interaction_kernels.{name}` moved to `{new_home}`. Use `from {new_home} import {name}`."
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(message, DeprecationWarning, stacklevel=2)

    # Without Triton every legacy name resolves to the raising stub.
    if not _HAS_TRITON:
        return _needs_triton

    # Imports stay lazy so each submodule is only loaded when its name is requested.
    if name == "maxsim_forward":
        from .forward import maxsim_forward as resolved
    elif name == "maxsim_topk":
        from .topk import maxsim_topk as resolved
    elif name == "maxsim_residual_inference":
        from .plaid import maxsim_residual_inference as resolved
    elif name == "maxsim_varlen_inference":
        from .varlen import maxsim_varlen_inference as resolved
    elif name in _DEPRECATED_EXPERIMENTAL:
        from . import experimental

        resolved = getattr(experimental, name)
    else:
        from . import fp8

        resolved = getattr(fp8, name)
    return resolved
178
+
179
+
180
+ __all__ = [
181
+ "__version__",
182
+ # high-level
183
+ "MaxSimScorer",
184
+ "retrieve",
185
+ "patch_pylate",
186
+ "unpatch_pylate",
187
+ # core MaxSim
188
+ "maxsim",
189
+ "maxsim_inference",
190
+ "maxsim_varlen",
191
+ # reranking on packed batches
192
+ "maxsim_inference_scatter",
193
+ # fused D-side head
194
+ "maxsim_from_hidden",
195
+ "maxsim_from_hidden_train",
196
+ # PLAID / ColBERTv2
197
+ "plaid_approx_score",
198
+ "maxsim_residual",
199
+ "maxsim_residual_varlen",
200
+ # FP8 inference
201
+ "maxsim_inference_fp8",
202
+ # configuration
203
+ "set_backward_method",
204
+ "get_backward_method",
205
+ "reference",
206
+ ]
@@ -0,0 +1,129 @@
1
+ """Autotune configs per GPU family.
2
+
3
+ Triton autotune runs each candidate once on the first call for a given key,
4
+ caches the winner, and reuses it forever after. We keep lists short — each
5
+ extra config costs one real launch.
6
+
7
+ Family rules of thumb (verified on H100 / A100 and conservative on the rest):
8
+ - Small ``d`` (≤ 128): prefer `BLOCK_Q=32-64, BLOCK_D=64-128`.
9
+ - Large ``d`` (≥ 512): shrink blocks so the fp16 `Q`/`D` tiles plus the fp32
10
+ `S` tile fit in the SM's shared-memory budget.
11
+ - Hopper loves `num_stages ≥ 3` (warp specialization + async copy).
12
+ - Ampere / Ada are happiest with `num_stages=2`.
13
+
14
+ Per-family SRAM budgets (KiB of shared memory the kernel can actually use):
15
+ - Hopper (H100 / H200): 228
16
+ - Ampere (A100): 164
17
+ - Ampere consumer (3090, A10): 100
18
+ - Ada (L4, L40, RTX 4090): 100
19
+ - Unknown / older: 48 (safe floor)
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import inspect
25
+
26
+ import triton
27
+
28
+ from ._utils import detect_gpu
29
+
30
+ # Warp specialization (FA-3 style) requires Triton 3.2+. The
31
+ # ``num_consumer_groups`` / ``num_buffers_warp_spec`` kwargs on
32
+ # ``triton.Config`` opt a kernel into producer-consumer warp specialization
33
+ # so loads overlap cleanly with ``tl.dot``. We feature-detect so we still
34
+ # run on older Triton.
35
try:
    _CFG_PARAMS = {*inspect.signature(triton.Config).parameters}
except (TypeError, ValueError):  # pragma: no cover
    # Signature could not be introspected: assume no optional kwargs.
    _CFG_PARAMS = set()
_HAS_WARP_SPEC = _CFG_PARAMS.issuperset(("num_consumer_groups", "num_buffers_warp_spec"))


def _cfg(kwargs, *, num_warps, num_stages, warp_spec=False):
    """Build a ``triton.Config`` for the given meta-parameters.

    When ``warp_spec`` is requested and the installed ``triton.Config`` accepts
    both warp-specialization kwargs, the config additionally sets
    ``num_consumer_groups=2`` and ``num_buffers_warp_spec=num_stages``.
    Otherwise the flag is silently ignored.
    """
    options = {"num_warps": num_warps, "num_stages": num_stages}
    if _HAS_WARP_SPEC and warp_spec:
        options.update(num_consumer_groups=2, num_buffers_warp_spec=num_stages)
    return triton.Config(kwargs, **options)
51
+
52
+
53
def _small_d_hopper():
    """Return the Hopper shortlist: six plain configs followed by two warp-specialized ones."""
    # (BLOCK_Q, BLOCK_D, num_warps, num_stages)
    plain = (
        (32, 64, 4, 3),
        (32, 128, 8, 3),
        (64, 64, 4, 3),
        (64, 128, 8, 3),
        (128, 64, 8, 2),
        (128, 128, 8, 2),
    )
    # Requested with warp specialization; `_cfg` drops the request when the
    # running Triton does not expose the warp-spec kwargs.
    warp_specialized = (
        (64, 128, 8, 3),
        (128, 128, 8, 3),
    )

    shortlist = [
        _cfg({"BLOCK_Q": block_q, "BLOCK_D": block_d}, num_warps=warps, num_stages=stages)
        for block_q, block_d, warps, stages in plain
    ]
    shortlist += [
        _cfg({"BLOCK_Q": block_q, "BLOCK_D": block_d}, num_warps=warps, num_stages=stages, warp_spec=True)
        for block_q, block_d, warps, stages in warp_specialized
    ]
    return shortlist
67
+
68
+
69
def _small_d_ampere():
    """Return the shortlist that `forward_configs` uses for the 'a100', 'ampere' and 'ada' families."""
    # (BLOCK_Q, BLOCK_D, num_warps, num_stages)
    table = (
        (32, 64, 4, 2),
        (32, 128, 8, 2),
        (64, 64, 4, 2),
        (64, 128, 8, 2),
        (128, 64, 8, 1),
    )
    return [
        _cfg({"BLOCK_Q": block_q, "BLOCK_D": block_d}, num_warps=warps, num_stages=stages)
        for block_q, block_d, warps, stages in table
    ]
78
+
79
+
80
def _large_d_configs():
    """Return the small-block shortlist that is included for every GPU family."""
    # (BLOCK_Q, BLOCK_D, num_warps); all of these use two pipeline stages.
    table = (
        (16, 16, 2),
        (16, 32, 2),
        (32, 16, 2),
        (32, 32, 4),
        (32, 64, 4),
    )
    return [
        _cfg({"BLOCK_Q": block_q, "BLOCK_D": block_d}, num_warps=warps, num_stages=2)
        for block_q, block_d, warps in table
    ]
89
+
90
+
91
+ _SRAM_KIB_BY_FAMILY = {
92
+ "hopper": 228,
93
+ "a100": 164,
94
+ "ampere": 100,
95
+ "ada": 100,
96
+ "generic": 48,
97
+ }
98
+
99
+
100
def forward_configs():
    """Return the autotune candidates for the detected GPU family.

    The small-block shortlist always comes first; Hopper or Ampere/Ada
    candidates are appended when the family is recognized. Unknown GPUs get
    only the small-block shortlist.
    """
    family = detect_gpu()
    shortlist = _large_d_configs()
    if family == "hopper":
        shortlist = shortlist + _small_d_hopper()
    elif family in ("a100", "ampere", "ada"):
        shortlist = shortlist + _small_d_ampere()
    return shortlist
108
+
109
+
110
def prune_forward(configs, named_args, **kwargs):
    """Drop autotune configs that overflow shared memory or are oversized for the problem.

    Args:
        configs: candidate configs; each must expose ``kwargs["BLOCK_Q"]`` and
            ``kwargs["BLOCK_D"]``.
        named_args: mapping of kernel argument names to values. ``Lq``
            (default 32) and ``d`` (default 128) are read from it.
        **kwargs: accepted for signature compatibility and ignored.

    Returns:
        The surviving configs in their original order (possibly just one).
        When nothing survives, the (up to) two candidates with the smallest
        estimated shared-memory footprint, so the autotuner never receives an
        empty list and the fallback is the least likely to fail to launch.
    """
    Lq = named_args.get("Lq", 32)
    d = named_args.get("d", 128)
    # Reserve 8 KiB for Triton scratch; the rest is ours.
    sram_budget = (_SRAM_KIB_BY_FAMILY.get(detect_gpu(), 48) - 8) * 1024

    def footprint(cfg):
        """Estimated bytes: 2-byte Q tile + 2-byte D tile + fp32 S tile."""
        # NOTE(review): the estimate does not scale with `num_stages`; confirm
        # whether multi-stage pipelining needs extra headroom.
        bq, bd = cfg.kwargs["BLOCK_Q"], cfg.kwargs["BLOCK_D"]
        return (bq * d + bd * d) * 2 + bq * bd * 4

    keep = [
        cfg
        for cfg in configs
        # Fits in shared memory, and the Q tile is at most twice the query length.
        if footprint(cfg) <= sram_budget and cfg.kwargs["BLOCK_Q"] <= 2 * Lq
    ]
    if keep:
        return keep
    # Nothing passed the filters: fall back to the cheapest candidates rather
    # than whichever two happen to be listed first (`sorted` is stable, so
    # ties keep their original relative order).
    return sorted(configs, key=footprint)[:2]
@@ -0,0 +1,50 @@
1
+ """Small helpers shared across kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+
7
+ import torch
8
+
9
+
10
def next_pow2(x: int) -> int:
    """Return the smallest power of two that is >= ``x``.

    Any ``x <= 1`` (including 0 and negative values) yields 1.
    """
    return 1 if x <= 1 else 1 << (x - 1).bit_length()
15
+
16
+
17
+ @functools.lru_cache(maxsize=1)
18
+ def detect_gpu() -> str:
19
+ """Return a short GPU family string: 'hopper' | 'a100' | 'ada' | 'ampere' | 'generic'."""
20
+ if not torch.cuda.is_available():
21
+ return "generic"
22
+ name = torch.cuda.get_device_name().lower()
23
+ if "h100" in name or "h200" in name:
24
+ return "hopper"
25
+ if "a100" in name:
26
+ return "a100"
27
+ if "l4" in name or "l40" in name or "rtx 40" in name:
28
+ return "ada"
29
+ if "3090" in name or "a10" in name or "a40" in name:
30
+ return "ampere"
31
+ return "generic"
32
+
33
+
34
def ensure_contiguous_last(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` itself when its last dimension has unit stride, else ``x.contiguous()``.

    Only the last stride is inspected, so a tensor that is strided along its
    leading dimensions is passed through without a copy.
    """
    last_dim_is_dense = x.stride(-1) == 1
    return x if last_dim_is_dense else x.contiguous()
39
+
40
+
41
def pick_compute_dtype(Q: torch.Tensor, D: torch.Tensor) -> torch.dtype:
    """Pick the reduced-precision dtype intended for the `tl.dot` operands.

    Returns ``torch.bfloat16`` as soon as either input is bf16, regardless of
    the other input's dtype. Every other combination (fp16 and/or fp32 inputs)
    yields ``torch.float16``; fp32 is never returned.
    """
    if torch.bfloat16 in (Q.dtype, D.dtype):
        return torch.bfloat16
    return torch.float16
@@ -0,0 +1,244 @@
1
+ """User-facing autograd wrapper for the fused MaxSim kernel."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import warnings
7
+
8
+ import torch
9
+
10
+ from .backward import maxsim_backward
11
+ from .backward_unified import maxsim_backward_unified
12
+ from .forward import _run_forward, maxsim_forward
13
+
14
_VALID_METHODS = ("auto", "atomic", "csr", "unified")

# Process-wide default `grad_D` strategy; changed via `set_backward_method`.
_BACKWARD_METHOD = "auto"

# One-shot latch: set once the "unnormalized inputs" warning has fired so a
# tight training loop does not spam the user's logs.
_WARNED_UNNORMALIZED = False


def set_backward_method(method: str) -> None:
    """Set the process-wide default ``grad_D`` path.

    The per-call ``backward=`` kwarg on :func:`maxsim` and
    :class:`~late_interaction_kernels.MaxSimScorer` is the preferred knob;
    this global exists for back-compat and for pinning one method across a
    benchmark run.

    Accepted values:

    * ``"auto"`` — ``"unified"`` for almost every shape; ``"csr"`` for
      very high ``grad_D`` contention (``Nq ≥ 256 ∧ Nd ≥ 256 ∧ Lq ≤ 64``).
    * ``"unified"`` — single-pass fused ``grad_Q + grad_D`` kernel.
    * ``"csr"`` — scatter-free bucketed reduction; bitwise-deterministic.
    * ``"atomic"`` — legacy two-pass with fp32 ``tl.atomic_add``.

    Raises:
        ValueError: ``method`` is not one of the accepted values; the current
            default is left untouched.
    """
    global _BACKWARD_METHOD
    if method in _VALID_METHODS:
        _BACKWARD_METHOD = method
        return
    raise ValueError(f"method must be one of {_VALID_METHODS}, got {method!r}")


def get_backward_method() -> str:
    """Return the current process-wide default ``grad_D`` strategy."""
    return _BACKWARD_METHOD
46
+
47
+
48
def _maybe_warn_unnormalized(Q: torch.Tensor) -> None:
    """Emit a one-time ``UserWarning`` when Q does not look L2-normalized.

    The median L2 norm of (at most) the first 64 token vectors of ``Q`` is
    compared against the band ``[0.9, 1.1]``. Outside that band the warning
    fires once per process. Setting ``LIK_SUPPRESS_NORM_WARN=1`` disables the
    check entirely.
    """
    global _WARNED_UNNORMALIZED
    if _WARNED_UNNORMALIZED:
        return
    if os.environ.get("LIK_SUPPRESS_NORM_WARN", "0") == "1":
        return

    with torch.no_grad():
        # Collapse every leading dim into one token axis and keep a small prefix.
        tokens = Q.detach().reshape(-1, Q.shape[-1])[:64]
        if tokens.numel() == 0:
            return
        # NOTE(review): `.item()` copies the value to the host, and this path
        # runs on every call until a warning has fired — confirm that is
        # acceptable in hot loops with already-normalized inputs.
        med = tokens.float().norm(dim=-1).median().item()

    if 0.9 <= med <= 1.1:
        return

    # Latch before warning so the flag is set even if the warning is
    # escalated to an exception by the active warning filters.
    _WARNED_UNNORMALIZED = True
    warnings.warn(
        f"late-interaction-kernels: `maxsim(..., normalize=False)` but Q's median L2 norm "
        f"is {med:.3f} (ColBERT-style models expect ≈1.0). Pass `normalize=True` to fuse "
        "the L2-norm into the kernel, or pre-normalize with `F.normalize(Q, dim=-1)`. "
        "Silence with `LIK_SUPPRESS_NORM_WARN=1`.",
        UserWarning,
        stacklevel=3,
    )
77
+
78
+
79
class _MaxSimFn(torch.autograd.Function):
    """Autograd bridge for fused MaxSim on 3-D inputs; the forward saves the argmax for backward."""

    @staticmethod
    def forward(ctx, Q, D, q_mask, d_mask, normalize, backward_method):
        """Run the fused forward, stash what backward needs, and return the scores."""
        ctx.normalize = normalize
        ctx.backward_method = backward_method
        scores, argmax = _run_forward(Q, D, q_mask, d_mask, save_argmax=True, normalize=normalize)
        ctx.save_for_backward(Q, D, argmax, q_mask, d_mask)
        return scores

    @staticmethod
    def backward(ctx, grad_scores):
        """Return ``(grad_Q, grad_D)`` followed by ``None`` for the four non-differentiable inputs."""
        Q, D, argmax, q_mask, d_mask = ctx.saved_tensors
        grad_scores = grad_scores.contiguous().to(torch.float32)

        # Resolve "auto": "csr" only for large square batches with short
        # queries (very high `grad_D` contention), "unified" otherwise.
        strategy = ctx.backward_method
        if strategy == "auto":
            n_queries, query_len, _ = Q.shape
            n_docs = D.shape[0]
            if n_queries >= 256 and n_docs >= 256 and query_len <= 64:
                strategy = "csr"
            else:
                strategy = "unified"

        def run_kernel(q_in, d_in):
            # Everything except "unified" goes through the two-pass backward.
            if strategy != "unified":
                return maxsim_backward(grad_scores, q_in, d_in, argmax, q_mask, d_mask, method=strategy)
            return maxsim_backward_unified(grad_scores, q_in, d_in, argmax, q_mask=q_mask, method="atomic")

        if not ctx.normalize:
            grad_Q, grad_D = run_kernel(Q, D)
            return grad_Q, grad_D, None, None, None, None

        def through_l2_normalize(grad_unit, unit, norm):
            # Jacobian of x -> x / ||x||:  (I - x_hat x_hat^T) / ||x||
            return (grad_unit - (grad_unit * unit).sum(-1, keepdim=True) * unit) / norm

        # The forward scored Q / ||Q|| against D / ||D||. Run the kernel on the
        # normalized tensors, then map the gradients back to the raw inputs.
        q_norm = torch.linalg.vector_norm(Q, dim=-1, keepdim=True).clamp_min(1e-6)
        d_norm = torch.linalg.vector_norm(D, dim=-1, keepdim=True).clamp_min(1e-6)
        Q_hat = Q / q_norm
        D_hat = D / d_norm
        grad_Qh, grad_Dh = run_kernel(Q_hat, D_hat)
        grad_Q = through_l2_normalize(grad_Qh, Q_hat, q_norm)
        grad_D = through_l2_normalize(grad_Dh, D_hat, d_norm)
        # masks, normalize, backward_method receive no gradient
        return grad_Q, grad_D, None, None, None, None
134
+
135
+
136
def maxsim(
    Q: torch.Tensor,
    D: torch.Tensor,
    q_mask: torch.Tensor | None = None,
    d_mask: torch.Tensor | None = None,
    *,
    normalize: bool = False,
    backward: str | None = None,
) -> torch.Tensor:
    """Differentiable fused MaxSim. Drop-in for PyLate's ``colbert_scores``.

    Args:
        Q: ``[Nq, Lq, d]`` or ``[Lq, d]``.
        D: ``[Nd, Ld, d]`` or ``[Ld, d]``.
        q_mask, d_mask: bool tensors (``True`` = valid token). A 1-D mask is
            given a leading batch dim.
        normalize: L2-normalize Q and D per-token inside the kernel. Use
            ``True`` for ColBERT / ColPali / LateOn-style scoring.
        backward: per-call ``grad_D`` strategy
            (``"auto" | "unified" | "csr" | "atomic"``). ``None`` defers to
            the default installed by :func:`set_backward_method`.

    Returns:
        ``[Nq, Nd]`` scores. The batch dim of any input that arrived 2-D is
        squeezed away, so two 2-D inputs yield a 0-d tensor.

    Raises:
        ValueError: embedding dims differ, a tensor or mask is on a different
            device than its partner, or ``backward`` is not a valid method.
    """
    # Remember which inputs arrived without a batch dim so the output can be
    # squeezed to match.
    squeeze_q = Q.dim() == 2
    squeeze_d = D.dim() == 2
    Q = Q.unsqueeze(0) if squeeze_q else Q
    D = D.unsqueeze(0) if squeeze_d else D
    if q_mask is not None and q_mask.dim() == 1:
        q_mask = q_mask.unsqueeze(0)
    if d_mask is not None and d_mask.dim() == 1:
        d_mask = d_mask.unsqueeze(0)

    # Fail fast on contract violations instead of handing bad inputs to the kernel.
    q_dim, d_dim = Q.shape[-1], D.shape[-1]
    if q_dim != d_dim:
        raise ValueError(
            f"Q and D must share the embedding dim; got Q.shape[-1]={Q.shape[-1]} "
            f"vs D.shape[-1]={D.shape[-1]}."
        )
    if Q.device != D.device:
        raise ValueError(
            f"Q and D must be on the same device; got Q.device={Q.device} vs D.device={D.device}."
        )
    if q_mask is not None and q_mask.device != Q.device:
        raise ValueError(f"q_mask must be on the same device as Q; got {q_mask.device} vs {Q.device}.")
    if d_mask is not None and d_mask.device != D.device:
        raise ValueError(f"d_mask must be on the same device as D; got {d_mask.device} vs {D.device}.")

    if backward is not None and backward not in _VALID_METHODS:
        raise ValueError(f"backward= must be one of {_VALID_METHODS} or None, got {backward!r}")
    method = _BACKWARD_METHOD if backward is None else backward

    if not normalize:
        _maybe_warn_unnormalized(Q)

    # The autograd function receives contiguous tensors and int8 masks.
    Q = Q.contiguous()
    D = D.contiguous()
    q_mask_i8 = None if q_mask is None else q_mask.contiguous().to(torch.int8)
    d_mask_i8 = None if d_mask is None else d_mask.contiguous().to(torch.int8)

    scores = _MaxSimFn.apply(Q, D, q_mask_i8, d_mask_i8, normalize, method)

    if squeeze_q:
        return scores.reshape(()) if squeeze_d else scores.squeeze(0)
    return scores.squeeze(-1) if squeeze_d else scores
214
+
215
+
216
def maxsim_inference(
    Q: torch.Tensor,
    D: torch.Tensor,
    q_mask: torch.Tensor | None = None,
    d_mask: torch.Tensor | None = None,
    *,
    normalize: bool = False,
) -> torch.Tensor:
    """Inference-only MaxSim — like :func:`maxsim` but no saved argmax.

    Args:
        Q, D: query / document token embeddings sharing the last (embedding) dim.
        q_mask, d_mask: optional validity masks, forwarded to the kernel as-is.
        normalize: L2-normalize Q and D per-token inside the kernel.

    Returns:
        The scores produced by ``maxsim_forward`` with ``save_argmax=False``.

    Raises:
        ValueError: embedding dims differ, or a tensor / mask lives on a
            different device than its partner.
    """
    if Q.shape[-1] != D.shape[-1]:
        raise ValueError(
            f"Q and D must share the embedding dim; got Q.shape[-1]={Q.shape[-1]} "
            f"vs D.shape[-1]={D.shape[-1]}."
        )
    if Q.device != D.device:
        raise ValueError(
            f"Q and D must be on the same device; got Q.device={Q.device} vs D.device={D.device}."
        )
    # Same fail-fast mask checks as `maxsim`, so a misplaced mask gives a clear
    # error here rather than an opaque failure further down.
    if q_mask is not None and q_mask.device != Q.device:
        raise ValueError(f"q_mask must be on the same device as Q; got {q_mask.device} vs {Q.device}.")
    if d_mask is not None and d_mask.device != D.device:
        raise ValueError(f"d_mask must be on the same device as D; got {d_mask.device} vs {D.device}.")
    if not normalize:
        _maybe_warn_unnormalized(Q)
    scores, _ = maxsim_forward(
        Q,
        D,
        q_mask=q_mask,
        d_mask=d_mask,
        save_argmax=False,
        normalize=normalize,
    )
    return scores