rat-embed 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ .venv/
5
+ *.egg-info/
6
+
7
+ # Data (generated at runtime)
8
+ data/*.npy
9
+ data/*.json
10
+
11
+ # OS
12
+ .DS_Store
13
+ *.Zone.Identifier
14
+ Thumbs.db
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 sojir
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,99 @@
1
+ Metadata-Version: 2.4
2
+ Name: rat-embed
3
+ Version: 0.1.0
4
+ Summary: Relative Anchor Translation: zero-shot embedding space translation via similarity profiles
5
+ Project-URL: Homepage, https://github.com/jiro-prog/rat-experiment
6
+ Project-URL: Repository, https://github.com/jiro-prog/rat-experiment
7
+ Project-URL: Paper, https://zenodo.org/records/19401277
8
+ Project-URL: Issues, https://github.com/jiro-prog/rat-experiment/issues
9
+ Author: So Jiro
10
+ License-Expression: MIT
11
+ License-File: LICENSE
12
+ Keywords: cross-model retrieval,embedding,model translation,relative representation,zero-shot
13
+ Classifier: Development Status :: 3 - Alpha
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.10
18
+ Classifier: Programming Language :: Python :: 3.11
19
+ Classifier: Programming Language :: Python :: 3.12
20
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
21
+ Requires-Python: >=3.10
22
+ Requires-Dist: numpy>=1.24
23
+ Provides-Extra: dev
24
+ Requires-Dist: pytest>=7.0; extra == 'dev'
25
+ Requires-Dist: ruff>=0.1; extra == 'dev'
26
+ Provides-Extra: models
27
+ Requires-Dist: sentence-transformers>=2.2; extra == 'models'
28
+ Description-Content-Type: text/markdown
29
+
30
+ # RAT — Relative Anchor Translation
31
+
32
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.19401277.svg)](https://zenodo.org/records/19401277)
33
+
34
+ Zero-shot embedding space translation using relative distances to shared anchors. No additional training required.
35
+
36
+ ## Install
37
+
38
+ ```bash
39
+ pip install -e . # core (numpy only)
40
+ pip install -e ".[models]" # + sentence-transformers for fit()
41
+ pip install -e ".[dev]" # + pytest, ruff
42
+ ```
43
+
44
+ ## Quick Start
45
+
46
+ ```python
47
+ import numpy as np
48
+ from rat import RATranslator
49
+
50
+ # 1. Prepare anchor embeddings from both models (K anchors, L2-normalized)
51
+ anchor_a = ... # (K, D_a) from model A
52
+ anchor_b = ... # (K, D_b) from model B
53
+
54
+ # 2. Fit the translator
55
+ translator = RATranslator(kernel="poly").fit_embeddings(anchor_a, anchor_b)
56
+
57
+ # 3. Transform & retrieve
58
+ query_emb = ... # (N, D_a) from model A
59
+ db_emb = ... # (M, D_b) from model B
60
+ results = translator.retrieve(query_emb, db_emb, top_k=10)
61
+ # results["indices"] → (N, 10) nearest neighbor indices
62
+ # results["scores"] → (N, 10) cosine similarity scores
63
+ ```
64
+
65
+ Or transform individually for more control:
66
+
67
+ ```python
68
+ q_rel = translator.transform(query_emb, "a") # query side (no z-score)
69
+ d_rel = translator.transform(db_emb, "b") # db side (z-score applied)
70
+ d_rel = translator.transform(db_emb, "b", role="query") # override: skip z-score
71
+ ```
72
+
73
+ ## Advanced: RATHub (multi-model)
74
+
75
+ ```python
76
+ from rat import RATHub
77
+
78
+ hub = RATHub(kernel="poly")
79
+ hub.set_anchors("minilm", anchor_minilm) # (K, 384)
80
+ hub.set_anchors("e5", anchor_e5) # (K, 1024)
81
+ hub.set_anchors("bge", anchor_bge) # (K, 384)
82
+
83
+ # Transform from any model
84
+ q = hub.transform("minilm", query_emb, role="query")
85
+ d = hub.transform("e5", db_emb, role="db")
86
+
87
+ # Or use retrieve directly
88
+ results = hub.retrieve(query_emb, db_emb, "minilm", "e5", top_k=10)
89
+ ```
90
+
91
+ ## Paper
92
+
93
+ See the [Zenodo record](https://zenodo.org/records/19401277) for the full experiment report.
94
+
95
+ Experiment reproduction code is in `experiments/` (unchanged from the original research).
96
+
97
+ ## License
98
+
99
+ MIT
@@ -0,0 +1,70 @@
1
+ # RAT — Relative Anchor Translation
2
+
3
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.19401277.svg)](https://zenodo.org/records/19401277)
4
+
5
+ Zero-shot embedding space translation using relative distances to shared anchors. No additional training required.
6
+
7
+ ## Install
8
+
9
+ ```bash
10
+ pip install -e . # core (numpy only)
11
+ pip install -e ".[models]" # + sentence-transformers for fit()
12
+ pip install -e ".[dev]" # + pytest, ruff
13
+ ```
14
+
15
+ ## Quick Start
16
+
17
+ ```python
18
+ import numpy as np
19
+ from rat import RATranslator
20
+
21
+ # 1. Prepare anchor embeddings from both models (K anchors, L2-normalized)
22
+ anchor_a = ... # (K, D_a) from model A
23
+ anchor_b = ... # (K, D_b) from model B
24
+
25
+ # 2. Fit the translator
26
+ translator = RATranslator(kernel="poly").fit_embeddings(anchor_a, anchor_b)
27
+
28
+ # 3. Transform & retrieve
29
+ query_emb = ... # (N, D_a) from model A
30
+ db_emb = ... # (M, D_b) from model B
31
+ results = translator.retrieve(query_emb, db_emb, top_k=10)
32
+ # results["indices"] → (N, 10) nearest neighbor indices
33
+ # results["scores"] → (N, 10) cosine similarity scores
34
+ ```
35
+
36
+ Or transform individually for more control:
37
+
38
+ ```python
39
+ q_rel = translator.transform(query_emb, "a") # query side (no z-score)
40
+ d_rel = translator.transform(db_emb, "b") # db side (z-score applied)
41
+ d_rel = translator.transform(db_emb, "b", role="query") # override: skip z-score
42
+ ```
43
+
44
+ ## Advanced: RATHub (multi-model)
45
+
46
+ ```python
47
+ from rat import RATHub
48
+
49
+ hub = RATHub(kernel="poly")
50
+ hub.set_anchors("minilm", anchor_minilm) # (K, 384)
51
+ hub.set_anchors("e5", anchor_e5) # (K, 1024)
52
+ hub.set_anchors("bge", anchor_bge) # (K, 384)
53
+
54
+ # Transform from any model
55
+ q = hub.transform("minilm", query_emb, role="query")
56
+ d = hub.transform("e5", db_emb, role="db")
57
+
58
+ # Or use retrieve directly
59
+ results = hub.retrieve(query_emb, db_emb, "minilm", "e5", top_k=10)
60
+ ```
61
+
62
+ ## Paper
63
+
64
+ See the [Zenodo record](https://zenodo.org/records/19401277) for the full experiment report.
65
+
66
+ Experiment reproduction code is in `experiments/` (unchanged from the original research).
67
+
68
+ ## License
69
+
70
+ MIT
@@ -0,0 +1,70 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "rat-embed"
7
+ version = "0.1.0"
8
+ description = "Relative Anchor Translation: zero-shot embedding space translation via similarity profiles"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ requires-python = ">=3.10"
12
+ authors = [
13
+ { name = "So Jiro" },
14
+ ]
15
+ keywords = [
16
+ "embedding",
17
+ "zero-shot",
18
+ "model translation",
19
+ "relative representation",
20
+ "cross-model retrieval",
21
+ ]
22
+ classifiers = [
23
+ "Development Status :: 3 - Alpha",
24
+ "Intended Audience :: Science/Research",
25
+ "License :: OSI Approved :: MIT License",
26
+ "Programming Language :: Python :: 3",
27
+ "Programming Language :: Python :: 3.10",
28
+ "Programming Language :: Python :: 3.11",
29
+ "Programming Language :: Python :: 3.12",
30
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
31
+ ]
32
+ dependencies = [
33
+ "numpy>=1.24",
34
+ ]
35
+
36
+ [project.optional-dependencies]
37
+ models = ["sentence-transformers>=2.2"]
38
+ dev = ["pytest>=7.0", "ruff>=0.1"]
39
+
40
+ [project.urls]
41
+ Homepage = "https://github.com/jiro-prog/rat-experiment"
42
+ Repository = "https://github.com/jiro-prog/rat-experiment"
43
+ Paper = "https://zenodo.org/records/19401277"
44
+ Issues = "https://github.com/jiro-prog/rat-experiment/issues"
45
+
46
+ [tool.hatch.build.targets.sdist]
47
+ exclude = [
48
+ "data/",
49
+ "experiments/",
50
+ "examples/",
51
+ "results/",
52
+ "figures/",
53
+ "paper/",
54
+ "scripts/",
55
+ "src/",
56
+ "*.npy",
57
+ "config.py",
58
+ "requirements.txt",
59
+ "RAT-phase0-experiment.md",
60
+ ]
61
+
62
+ [tool.hatch.build.targets.wheel]
63
+ packages = ["rat"]
64
+
65
+ [tool.pytest.ini_options]
66
+ testpaths = ["tests"]
67
+
68
+ [tool.ruff]
69
+ target-version = "py310"
70
+ line-length = 100
@@ -0,0 +1,7 @@
1
+ """RAT: Relative Anchor Translation for zero-shot embedding space translation."""
2
+
3
+ from rat.translator import RATranslator
4
+ from rat.hub import RATHub
5
+
6
+ __all__ = ["RATranslator", "RATHub"]
7
+ __version__ = "0.1.0"
@@ -0,0 +1,293 @@
1
+ """RATHub: multi-model relative anchor translation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import warnings
7
+
8
+ import numpy as np
9
+
10
+ from rat.kernels import get_kernel
11
+ from rat.normalize import compute_sim_mean, normalize_zscore, recommend_zscore
12
+
13
+
14
+ class RATHub:
15
+ """Manage N embedding models and translate between any pair via RAT.
16
+
17
+ Each model's anchor embeddings are registered once. After that,
18
+ any model's embeddings can be transformed to the shared relative space.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ kernel: str = "poly",
24
+ kernel_params: dict | None = None,
25
+ normalize: str = "auto",
26
+ normalize_threshold: float = 0.15,
27
+ normalize_harmful_threshold: float = 0.65,
28
+ verbose: bool = False,
29
+ ):
30
+ self._kernel_name = kernel
31
+ self._kernel_params = kernel_params or {}
32
+ self._kernel_fn = get_kernel(kernel, kernel_params)
33
+ self._normalize = normalize
34
+ self._normalize_threshold = normalize_threshold
35
+ self._normalize_harmful_threshold = normalize_harmful_threshold
36
+ self._verbose = verbose
37
+
38
+ # Per-model state: {model_name: {...}}
39
+ self._models: dict[str, dict] = {}
40
+ self._anchor_k: int | None = None # enforced anchor count
41
+
42
+ # ── Registration ──
43
+
44
+ def set_anchors(
45
+ self,
46
+ model_name: str,
47
+ anchor_embeddings: np.ndarray,
48
+ ) -> None:
49
+ """Register anchor embeddings for a model.
50
+
51
+ Parameters
52
+ ----------
53
+ model_name : identifier for this model
54
+ anchor_embeddings : (K, D) L2-normalized
55
+ """
56
+ _check_l2_normalized(anchor_embeddings)
57
+
58
+ K = anchor_embeddings.shape[0]
59
+ if self._anchor_k is None:
60
+ self._anchor_k = K
61
+ elif K != self._anchor_k:
62
+ raise ValueError(
63
+ f"Anchor count mismatch: expected {self._anchor_k}, got {K} for '{model_name}'"
64
+ )
65
+
66
+ sim_mean = compute_sim_mean(anchor_embeddings)
67
+
68
+ self._models[model_name] = {
69
+ "anchors": anchor_embeddings,
70
+ "sim_mean": sim_mean,
71
+ }
72
+
73
+ if self._verbose:
74
+ rec = recommend_zscore(sim_mean, self._normalize_threshold, self._normalize_harmful_threshold)
75
+ print(f"[RAT] {model_name} sim_mean={sim_mean:.3f} → z-score: {rec}")
76
+ print(
77
+ f"[RAT] Anchor shape: {model_name}="
78
+ f"({anchor_embeddings.shape[0]}, {anchor_embeddings.shape[1]})"
79
+ )
80
+
81
+ # ── Transform ──
82
+
83
+ def transform(
84
+ self,
85
+ model_name: str,
86
+ embeddings: np.ndarray,
87
+ role: str = "auto",
88
+ ) -> np.ndarray:
89
+ """Transform embeddings to relative representation.
90
+
91
+ Parameters
92
+ ----------
93
+ model_name : which model produced these embeddings
94
+ embeddings : (N, D) L2-normalized
95
+ role : "query" (skip z-score), "db" (apply z-score per normalize setting),
96
+ "auto" (treat as db)
97
+
98
+ Returns
99
+ -------
100
+ (N, K) relative representation
101
+ """
102
+ if model_name not in self._models:
103
+ raise RuntimeError(
104
+ f"Model '{model_name}' not registered. Call set_anchors() first."
105
+ )
106
+ if role not in ("query", "db", "auto"):
107
+ raise ValueError(f"role must be 'query', 'db', or 'auto', got '{role}'")
108
+
109
+ _check_l2_normalized(embeddings)
110
+
111
+ state = self._models[model_name]
112
+ anchors = state["anchors"]
113
+ rel = self._kernel_fn(embeddings, anchors) # (N, K)
114
+
115
+ if self._should_apply_zscore(model_name, role):
116
+ rel = normalize_zscore(rel)
117
+
118
+ return rel
119
+
120
+ def _should_apply_zscore(self, model_name: str, role: str) -> bool:
121
+ """Decide whether to apply z-score for this model/role combination."""
122
+ if role == "query":
123
+ return False
124
+
125
+ # role is "db" or "auto" → check normalize setting
126
+ state = self._models[model_name]
127
+ sim_mean = state["sim_mean"]
128
+
129
+ if self._normalize == "never":
130
+ return False
131
+ if self._normalize == "always":
132
+ return True
133
+
134
+ # "auto": use recommendation
135
+ rec = recommend_zscore(sim_mean, self._normalize_threshold, self._normalize_harmful_threshold)
136
+ should = rec == "recommended"
137
+
138
+ if self._verbose:
139
+ if rec == "harmful":
140
+ print(
141
+ f"[RAT] {model_name} sim_mean={sim_mean:.3f} >= "
142
+ f"harmful_threshold {self._normalize_harmful_threshold} "
143
+ f"→ z-score: harmful (skipping)"
144
+ )
145
+ elif should:
146
+ print(
147
+ f"[RAT] {model_name} sim_mean={sim_mean:.3f} >= "
148
+ f"{self._normalize_threshold} → applying z-score"
149
+ )
150
+ else:
151
+ print(
152
+ f"[RAT] {model_name} sim_mean={sim_mean:.3f} < "
153
+ f"{self._normalize_threshold} → z-score: not needed"
154
+ )
155
+
156
+ return should
157
+
158
+ # ── Retrieval ──
159
+
160
+ def retrieve(
161
+ self,
162
+ query_emb: np.ndarray,
163
+ db_emb: np.ndarray,
164
+ from_model: str,
165
+ to_model: str,
166
+ top_k: int = 10,
167
+ ) -> dict:
168
+ """Transform + cosine similarity retrieval.
169
+
170
+ Parameters
171
+ ----------
172
+ query_emb : (N, D_from) query embeddings from from_model
173
+ db_emb : (M, D_to) database embeddings from to_model
174
+ from_model : model name for queries
175
+ to_model : model name for database
176
+ top_k : number of results per query
177
+
178
+ Returns
179
+ -------
180
+ {"indices": (N, top_k), "scores": (N, top_k)}
181
+ """
182
+ q_rel = self.transform(from_model, query_emb, role="query")
183
+ d_rel = self.transform(to_model, db_emb, role="db")
184
+
185
+ # L2-normalize relative representations for cosine similarity
186
+ q_norm = q_rel / (np.linalg.norm(q_rel, axis=1, keepdims=True) + 1e-10)
187
+ d_norm = d_rel / (np.linalg.norm(d_rel, axis=1, keepdims=True) + 1e-10)
188
+
189
+ scores = q_norm @ d_norm.T # (N, M)
190
+ top_k = min(top_k, scores.shape[1])
191
+ indices = np.argsort(-scores, axis=1)[:, :top_k]
192
+ sorted_scores = np.take_along_axis(scores, indices, axis=1)
193
+
194
+ return {"indices": indices, "scores": sorted_scores}
195
+
196
+ # ── Compatibility ──
197
+
198
+ def estimate_compatibility(
199
+ self,
200
+ model_a: str,
201
+ model_b: str,
202
+ ) -> dict:
203
+ """Estimate RAT compatibility between two registered models.
204
+
205
+ Returns
206
+ -------
207
+ dict with sim_mean_a, sim_mean_b, max_sim_mean, estimated_recall_at_1,
208
+ z_score_recommendation, warnings
209
+ """
210
+ if model_a not in self._models:
211
+ raise RuntimeError(f"Model '{model_a}' not registered.")
212
+ if model_b not in self._models:
213
+ raise RuntimeError(f"Model '{model_b}' not registered.")
214
+
215
+ sm_a = self._models[model_a]["sim_mean"]
216
+ sm_b = self._models[model_b]["sim_mean"]
217
+ max_sm = max(sm_a, sm_b)
218
+ rec = recommend_zscore(max_sm, self._normalize_threshold, self._normalize_harmful_threshold)
219
+
220
+ warn_list: list[str] = []
221
+ if max_sm > self._normalize_harmful_threshold:
222
+ warn_list.append(
223
+ f"High sim_mean ({max_sm:.3f}) detected. "
224
+ "Z-score normalization may be harmful for this pair."
225
+ )
226
+
227
+ return {
228
+ "sim_mean_a": sm_a,
229
+ "sim_mean_b": sm_b,
230
+ "max_sim_mean": max_sm,
231
+ "estimated_recall_at_1": None, # Phase 2
232
+ "z_score_recommendation": rec,
233
+ "warnings": warn_list,
234
+ }
235
+
236
+ # ── Persistence ──
237
+
238
+ def save(self, path: str) -> None:
239
+ """Save hub state to .npz file."""
240
+ save_dict: dict[str, object] = {}
241
+
242
+ model_names = list(self._models.keys())
243
+ for name in model_names:
244
+ state = self._models[name]
245
+ save_dict[f"anchor_{name}"] = state["anchors"]
246
+
247
+ config = {
248
+ "kernel": self._kernel_name,
249
+ "kernel_params": self._kernel_params,
250
+ "normalize": self._normalize,
251
+ "normalize_threshold": self._normalize_threshold,
252
+ "normalize_harmful_threshold": self._normalize_harmful_threshold,
253
+ "model_names": model_names,
254
+ "sim_means": {n: self._models[n]["sim_mean"] for n in model_names},
255
+ "version": "0.1.0",
256
+ }
257
+ save_dict["config"] = np.array(json.dumps(config))
258
+
259
+ np.savez_compressed(path, **save_dict)
260
+
261
+ @classmethod
262
+ def load(cls, path: str) -> "RATHub":
263
+ """Load hub state from .npz file."""
264
+ if not path.endswith(".npz"):
265
+ path = path + ".npz"
266
+ data = np.load(path, allow_pickle=False)
267
+ config = json.loads(str(data["config"]))
268
+
269
+ hub = cls(
270
+ kernel=config["kernel"],
271
+ kernel_params=config.get("kernel_params"),
272
+ normalize=config["normalize"],
273
+ normalize_threshold=config["normalize_threshold"],
274
+ normalize_harmful_threshold=config.get("normalize_harmful_threshold", 0.65),
275
+ )
276
+
277
+ for name in config["model_names"]:
278
+ anchors = data[f"anchor_{name}"]
279
+ hub.set_anchors(name, anchors)
280
+ # Override sim_mean from saved config (avoids recomputation drift)
281
+ hub._models[name]["sim_mean"] = config["sim_means"][name]
282
+
283
+ return hub
284
+
285
+
286
+ def _check_l2_normalized(embeddings: np.ndarray) -> None:
287
+ """Warn if embeddings don't appear L2-normalized."""
288
+ norms = np.linalg.norm(embeddings, axis=1)
289
+ if not np.allclose(norms, 1.0, atol=0.01):
290
+ warnings.warn(
291
+ "Input embeddings may not be L2-normalized. RAT expects normalized embeddings.",
292
+ stacklevel=3,
293
+ )
@@ -0,0 +1,97 @@
1
+ """Kernel functions for computing relative representations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+
9
+
10
def poly_kernel(
    X: np.ndarray,
    A: np.ndarray,
    degree: int = 2,
    coef0: float = 1.0,
) -> np.ndarray:
    """Polynomial kernel: (X @ A.T + coef0) ** degree.

    Parameters
    ----------
    X : (N, D) input embeddings
    A : (K, D) anchor embeddings
    degree : polynomial degree
    coef0 : constant term

    Returns
    -------
    (N, K) kernel matrix
    """
    similarities = X @ A.T
    return np.power(similarities + coef0, degree)
30
+
31
+
32
def cosine_kernel(
    X: np.ndarray,
    A: np.ndarray,
) -> np.ndarray:
    """Cosine similarity kernel: X @ A.T (assumes L2-normalized inputs).

    Parameters
    ----------
    X : (N, D) input embeddings, L2-normalized
    A : (K, D) anchor embeddings, L2-normalized

    Returns
    -------
    (N, K) similarity matrix
    """
    # For unit-norm rows the inner product IS the cosine similarity.
    return np.matmul(X, A.T)
48
+
49
+
50
+ def rbf_kernel(
51
+ X: np.ndarray,
52
+ A: np.ndarray,
53
+ gamma: float | None = None,
54
+ ) -> np.ndarray:
55
+ """RBF (Gaussian) kernel: exp(-gamma * ||x - a||^2).
56
+
57
+ Parameters
58
+ ----------
59
+ X : (N, D) input embeddings
60
+ A : (K, D) anchor embeddings
61
+ gamma : kernel width. None uses 1/D
62
+
63
+ Returns
64
+ -------
65
+ (N, K) kernel matrix
66
+ """
67
+ if gamma is None:
68
+ gamma = 1.0 / X.shape[1]
69
+ # ||x - a||^2 = ||x||^2 + ||a||^2 - 2 x·a
70
+ X_sq = np.sum(X ** 2, axis=1, keepdims=True) # (N, 1)
71
+ A_sq = np.sum(A ** 2, axis=1, keepdims=True) # (K, 1)
72
+ dist_sq = X_sq + A_sq.T - 2.0 * (X @ A.T) # (N, K)
73
+ dist_sq = np.maximum(dist_sq, 0.0) # numerical safety
74
+ return np.exp(-gamma * dist_sq)
75
+
76
+
77
def get_kernel(name: str, params: dict | None = None):
    """Factory: return a callable(X, A) -> np.ndarray for the named kernel.

    Parameters
    ----------
    name : "poly", "cosine", or "rbf"
    params : optional kwargs forwarded to the kernel function

    Returns
    -------
    callable(X, A) -> np.ndarray

    Raises
    ------
    ValueError
        If *name* is not a known kernel.
    """
    registry = {
        "poly": poly_kernel,
        "cosine": cosine_kernel,
        "rbf": rbf_kernel,
    }
    if name not in registry:
        raise ValueError(f"Unknown kernel '{name}'. Choose from: {list(registry.keys())}")
    # Bind the extra kwargs now; the returned callable takes only (X, A).
    return partial(registry[name], **(params or {}))