rat-embed 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rat_embed-0.1.0/.gitignore +14 -0
- rat_embed-0.1.0/LICENSE +21 -0
- rat_embed-0.1.0/PKG-INFO +99 -0
- rat_embed-0.1.0/README.md +70 -0
- rat_embed-0.1.0/pyproject.toml +70 -0
- rat_embed-0.1.0/rat/__init__.py +7 -0
- rat_embed-0.1.0/rat/hub.py +293 -0
- rat_embed-0.1.0/rat/kernels.py +97 -0
- rat_embed-0.1.0/rat/normalize.py +68 -0
- rat_embed-0.1.0/rat/sampling.py +48 -0
- rat_embed-0.1.0/rat/translator.py +209 -0
- rat_embed-0.1.0/tests/test_hub.py +112 -0
- rat_embed-0.1.0/tests/test_integration.py +246 -0
- rat_embed-0.1.0/tests/test_kernels.py +106 -0
- rat_embed-0.1.0/tests/test_normalize.py +88 -0
- rat_embed-0.1.0/tests/test_sampling.py +50 -0
- rat_embed-0.1.0/tests/test_translator.py +206 -0
rat_embed-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 sojir
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
rat_embed-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rat-embed
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Relative Anchor Translation: zero-shot embedding space translation via similarity profiles
|
|
5
|
+
Project-URL: Homepage, https://github.com/jiro-prog/rat-experiment
|
|
6
|
+
Project-URL: Repository, https://github.com/jiro-prog/rat-experiment
|
|
7
|
+
Project-URL: Paper, https://zenodo.org/records/19401277
|
|
8
|
+
Project-URL: Issues, https://github.com/jiro-prog/rat-experiment/issues
|
|
9
|
+
Author: So Jiro
|
|
10
|
+
License-Expression: MIT
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Keywords: cross-model retrieval,embedding,model translation,relative representation,zero-shot
|
|
13
|
+
Classifier: Development Status :: 3 - Alpha
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
|
+
Requires-Dist: numpy>=1.24
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
25
|
+
Requires-Dist: ruff>=0.1; extra == 'dev'
|
|
26
|
+
Provides-Extra: models
|
|
27
|
+
Requires-Dist: sentence-transformers>=2.2; extra == 'models'
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+
# RAT — Relative Anchor Translation
|
|
31
|
+
|
|
32
|
+
[](https://zenodo.org/records/19401277)
|
|
33
|
+
|
|
34
|
+
Zero-shot embedding space translation using relative distances to shared anchors. No additional training required.
|
|
35
|
+
|
|
36
|
+
## Install
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
pip install -e . # core (numpy only)
|
|
40
|
+
pip install -e ".[models]" # + sentence-transformers for fit()
|
|
41
|
+
pip install -e ".[dev]" # + pytest, ruff
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Quick Start
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
import numpy as np
|
|
48
|
+
from rat import RATranslator
|
|
49
|
+
|
|
50
|
+
# 1. Prepare anchor embeddings from both models (K anchors, L2-normalized)
|
|
51
|
+
anchor_a = ... # (K, D_a) from model A
|
|
52
|
+
anchor_b = ... # (K, D_b) from model B
|
|
53
|
+
|
|
54
|
+
# 2. Fit the translator
|
|
55
|
+
translator = RATranslator(kernel="poly").fit_embeddings(anchor_a, anchor_b)
|
|
56
|
+
|
|
57
|
+
# 3. Transform & retrieve
|
|
58
|
+
query_emb = ... # (N, D_a) from model A
|
|
59
|
+
db_emb = ... # (M, D_b) from model B
|
|
60
|
+
results = translator.retrieve(query_emb, db_emb, top_k=10)
|
|
61
|
+
# results["indices"] → (N, 10) nearest neighbor indices
|
|
62
|
+
# results["scores"] → (N, 10) cosine similarity scores
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
Or transform individually for more control:
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
q_rel = translator.transform(query_emb, "a") # query side (no z-score)
|
|
69
|
+
d_rel = translator.transform(db_emb, "b") # db side (z-score applied)
|
|
70
|
+
d_rel = translator.transform(db_emb, "b", role="query") # override: skip z-score
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
## Advanced: RATHub (multi-model)
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
from rat import RATHub
|
|
77
|
+
|
|
78
|
+
hub = RATHub(kernel="poly")
|
|
79
|
+
hub.set_anchors("minilm", anchor_minilm) # (K, 384)
|
|
80
|
+
hub.set_anchors("e5", anchor_e5) # (K, 1024)
|
|
81
|
+
hub.set_anchors("bge", anchor_bge) # (K, 384)
|
|
82
|
+
|
|
83
|
+
# Transform from any model
|
|
84
|
+
q = hub.transform("minilm", query_emb, role="query")
|
|
85
|
+
d = hub.transform("e5", db_emb, role="db")
|
|
86
|
+
|
|
87
|
+
# Or use retrieve directly
|
|
88
|
+
results = hub.retrieve(query_emb, db_emb, "minilm", "e5", top_k=10)
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
## Paper
|
|
92
|
+
|
|
93
|
+
See the [Zenodo record](https://zenodo.org/records/19401277) for the full experiment report.
|
|
94
|
+
|
|
95
|
+
Experiment reproduction code is in `experiments/` (unchanged from the original research).
|
|
96
|
+
|
|
97
|
+
## License
|
|
98
|
+
|
|
99
|
+
MIT
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# RAT — Relative Anchor Translation
|
|
2
|
+
|
|
3
|
+
[](https://zenodo.org/records/19401277)
|
|
4
|
+
|
|
5
|
+
Zero-shot embedding space translation using relative distances to shared anchors. No additional training required.
|
|
6
|
+
|
|
7
|
+
## Install
|
|
8
|
+
|
|
9
|
+
```bash
|
|
10
|
+
pip install -e . # core (numpy only)
|
|
11
|
+
pip install -e ".[models]" # + sentence-transformers for fit()
|
|
12
|
+
pip install -e ".[dev]" # + pytest, ruff
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
## Quick Start
|
|
16
|
+
|
|
17
|
+
```python
|
|
18
|
+
import numpy as np
|
|
19
|
+
from rat import RATranslator
|
|
20
|
+
|
|
21
|
+
# 1. Prepare anchor embeddings from both models (K anchors, L2-normalized)
|
|
22
|
+
anchor_a = ... # (K, D_a) from model A
|
|
23
|
+
anchor_b = ... # (K, D_b) from model B
|
|
24
|
+
|
|
25
|
+
# 2. Fit the translator
|
|
26
|
+
translator = RATranslator(kernel="poly").fit_embeddings(anchor_a, anchor_b)
|
|
27
|
+
|
|
28
|
+
# 3. Transform & retrieve
|
|
29
|
+
query_emb = ... # (N, D_a) from model A
|
|
30
|
+
db_emb = ... # (M, D_b) from model B
|
|
31
|
+
results = translator.retrieve(query_emb, db_emb, top_k=10)
|
|
32
|
+
# results["indices"] → (N, 10) nearest neighbor indices
|
|
33
|
+
# results["scores"] → (N, 10) cosine similarity scores
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
Or transform individually for more control:
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
q_rel = translator.transform(query_emb, "a") # query side (no z-score)
|
|
40
|
+
d_rel = translator.transform(db_emb, "b") # db side (z-score applied)
|
|
41
|
+
d_rel = translator.transform(db_emb, "b", role="query") # override: skip z-score
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Advanced: RATHub (multi-model)
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from rat import RATHub
|
|
48
|
+
|
|
49
|
+
hub = RATHub(kernel="poly")
|
|
50
|
+
hub.set_anchors("minilm", anchor_minilm) # (K, 384)
|
|
51
|
+
hub.set_anchors("e5", anchor_e5) # (K, 1024)
|
|
52
|
+
hub.set_anchors("bge", anchor_bge) # (K, 384)
|
|
53
|
+
|
|
54
|
+
# Transform from any model
|
|
55
|
+
q = hub.transform("minilm", query_emb, role="query")
|
|
56
|
+
d = hub.transform("e5", db_emb, role="db")
|
|
57
|
+
|
|
58
|
+
# Or use retrieve directly
|
|
59
|
+
results = hub.retrieve(query_emb, db_emb, "minilm", "e5", top_k=10)
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
## Paper
|
|
63
|
+
|
|
64
|
+
See the [Zenodo record](https://zenodo.org/records/19401277) for the full experiment report.
|
|
65
|
+
|
|
66
|
+
Experiment reproduction code is in `experiments/` (unchanged from the original research).
|
|
67
|
+
|
|
68
|
+
## License
|
|
69
|
+
|
|
70
|
+
MIT
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "rat-embed"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Relative Anchor Translation: zero-shot embedding space translation via similarity profiles"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = "MIT"
|
|
11
|
+
requires-python = ">=3.10"
|
|
12
|
+
authors = [
|
|
13
|
+
{ name = "So Jiro" },
|
|
14
|
+
]
|
|
15
|
+
keywords = [
|
|
16
|
+
"embedding",
|
|
17
|
+
"zero-shot",
|
|
18
|
+
"model translation",
|
|
19
|
+
"relative representation",
|
|
20
|
+
"cross-model retrieval",
|
|
21
|
+
]
|
|
22
|
+
classifiers = [
|
|
23
|
+
"Development Status :: 3 - Alpha",
|
|
24
|
+
"Intended Audience :: Science/Research",
|
|
25
|
+
"License :: OSI Approved :: MIT License",
|
|
26
|
+
"Programming Language :: Python :: 3",
|
|
27
|
+
"Programming Language :: Python :: 3.10",
|
|
28
|
+
"Programming Language :: Python :: 3.11",
|
|
29
|
+
"Programming Language :: Python :: 3.12",
|
|
30
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
31
|
+
]
|
|
32
|
+
dependencies = [
|
|
33
|
+
"numpy>=1.24",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
[project.optional-dependencies]
|
|
37
|
+
models = ["sentence-transformers>=2.2"]
|
|
38
|
+
dev = ["pytest>=7.0", "ruff>=0.1"]
|
|
39
|
+
|
|
40
|
+
[project.urls]
|
|
41
|
+
Homepage = "https://github.com/jiro-prog/rat-experiment"
|
|
42
|
+
Repository = "https://github.com/jiro-prog/rat-experiment"
|
|
43
|
+
Paper = "https://zenodo.org/records/19401277"
|
|
44
|
+
Issues = "https://github.com/jiro-prog/rat-experiment/issues"
|
|
45
|
+
|
|
46
|
+
[tool.hatch.build.targets.sdist]
|
|
47
|
+
exclude = [
|
|
48
|
+
"data/",
|
|
49
|
+
"experiments/",
|
|
50
|
+
"examples/",
|
|
51
|
+
"results/",
|
|
52
|
+
"figures/",
|
|
53
|
+
"paper/",
|
|
54
|
+
"scripts/",
|
|
55
|
+
"src/",
|
|
56
|
+
"*.npy",
|
|
57
|
+
"config.py",
|
|
58
|
+
"requirements.txt",
|
|
59
|
+
"RAT-phase0-experiment.md",
|
|
60
|
+
]
|
|
61
|
+
|
|
62
|
+
[tool.hatch.build.targets.wheel]
|
|
63
|
+
packages = ["rat"]
|
|
64
|
+
|
|
65
|
+
[tool.pytest.ini_options]
|
|
66
|
+
testpaths = ["tests"]
|
|
67
|
+
|
|
68
|
+
[tool.ruff]
|
|
69
|
+
target-version = "py310"
|
|
70
|
+
line-length = 100
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""RATHub: multi-model relative anchor translation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import warnings
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from rat.kernels import get_kernel
|
|
11
|
+
from rat.normalize import compute_sim_mean, normalize_zscore, recommend_zscore
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class RATHub:
    """Manage N embedding models and translate between any pair via RAT.

    Each model's anchor embeddings are registered once. After that,
    any model's embeddings can be transformed to the shared relative space.

    Notes
    -----
    - All registered models must share the same anchor count K (enforced
      in :meth:`set_anchors`); dimensions D may differ per model.
    - "Relative representation" means the (N, K) kernel similarity of each
      embedding against that model's K anchors.
    """

    def __init__(
        self,
        kernel: str = "poly",
        kernel_params: dict | None = None,
        normalize: str = "auto",
        normalize_threshold: float = 0.15,
        normalize_harmful_threshold: float = 0.65,
        verbose: bool = False,
    ):
        # Keep name/params alongside the resolved callable so save() can
        # serialize the configuration without inspecting the function.
        self._kernel_name = kernel
        self._kernel_params = kernel_params or {}
        self._kernel_fn = get_kernel(kernel, kernel_params)
        # normalize is one of "auto" / "always" / "never" (see _should_apply_zscore)
        self._normalize = normalize
        self._normalize_threshold = normalize_threshold
        self._normalize_harmful_threshold = normalize_harmful_threshold
        self._verbose = verbose

        # Per-model state: {model_name: {"anchors": (K, D) array, "sim_mean": float}}
        self._models: dict[str, dict] = {}
        self._anchor_k: int | None = None  # enforced anchor count

    # ── Registration ──

    def set_anchors(
        self,
        model_name: str,
        anchor_embeddings: np.ndarray,
    ) -> None:
        """Register anchor embeddings for a model.

        Parameters
        ----------
        model_name : identifier for this model
        anchor_embeddings : (K, D) L2-normalized

        Raises
        ------
        ValueError
            If K disagrees with the anchor count of previously registered models.
        """
        # Warns (does not raise) if rows are not unit-norm.
        _check_l2_normalized(anchor_embeddings)

        # The first registration fixes K for the whole hub.
        K = anchor_embeddings.shape[0]
        if self._anchor_k is None:
            self._anchor_k = K
        elif K != self._anchor_k:
            raise ValueError(
                f"Anchor count mismatch: expected {self._anchor_k}, got {K} for '{model_name}'"
            )

        # Mean pairwise anchor similarity; drives the z-score recommendation.
        sim_mean = compute_sim_mean(anchor_embeddings)

        # Re-registering a name silently replaces its previous state.
        self._models[model_name] = {
            "anchors": anchor_embeddings,
            "sim_mean": sim_mean,
        }

        if self._verbose:
            rec = recommend_zscore(sim_mean, self._normalize_threshold, self._normalize_harmful_threshold)
            print(f"[RAT] {model_name} sim_mean={sim_mean:.3f} → z-score: {rec}")
            print(
                f"[RAT] Anchor shape: {model_name}="
                f"({anchor_embeddings.shape[0]}, {anchor_embeddings.shape[1]})"
            )

    # ── Transform ──

    def transform(
        self,
        model_name: str,
        embeddings: np.ndarray,
        role: str = "auto",
    ) -> np.ndarray:
        """Transform embeddings to relative representation.

        Parameters
        ----------
        model_name : which model produced these embeddings
        embeddings : (N, D) L2-normalized
        role : "query" (skip z-score), "db" (apply z-score per normalize setting),
            "auto" (treat as db)

        Returns
        -------
        (N, K) relative representation

        Raises
        ------
        RuntimeError
            If `model_name` was never registered via set_anchors().
        ValueError
            If `role` is not one of "query", "db", "auto".
        """
        if model_name not in self._models:
            raise RuntimeError(
                f"Model '{model_name}' not registered. Call set_anchors() first."
            )
        if role not in ("query", "db", "auto"):
            raise ValueError(f"role must be 'query', 'db', or 'auto', got '{role}'")

        _check_l2_normalized(embeddings)

        state = self._models[model_name]
        anchors = state["anchors"]
        rel = self._kernel_fn(embeddings, anchors)  # (N, K)

        # z-score is asymmetric by design: only the db side is normalized.
        if self._should_apply_zscore(model_name, role):
            rel = normalize_zscore(rel)

        return rel

    def _should_apply_zscore(self, model_name: str, role: str) -> bool:
        """Decide whether to apply z-score for this model/role combination."""
        # Queries are never z-scored regardless of the normalize setting.
        if role == "query":
            return False

        # role is "db" or "auto" → check normalize setting
        state = self._models[model_name]
        sim_mean = state["sim_mean"]

        if self._normalize == "never":
            return False
        if self._normalize == "always":
            return True

        # "auto": use recommendation; only an explicit "recommended" enables it
        # ("harmful" and any other verdict both mean skip).
        rec = recommend_zscore(sim_mean, self._normalize_threshold, self._normalize_harmful_threshold)
        should = rec == "recommended"

        if self._verbose:
            if rec == "harmful":
                print(
                    f"[RAT] {model_name} sim_mean={sim_mean:.3f} >= "
                    f"harmful_threshold {self._normalize_harmful_threshold} "
                    f"→ z-score: harmful (skipping)"
                )
            elif should:
                print(
                    f"[RAT] {model_name} sim_mean={sim_mean:.3f} >= "
                    f"{self._normalize_threshold} → applying z-score"
                )
            else:
                print(
                    f"[RAT] {model_name} sim_mean={sim_mean:.3f} < "
                    f"{self._normalize_threshold} → z-score: not needed"
                )

        return should

    # ── Retrieval ──

    def retrieve(
        self,
        query_emb: np.ndarray,
        db_emb: np.ndarray,
        from_model: str,
        to_model: str,
        top_k: int = 10,
    ) -> dict:
        """Transform + cosine similarity retrieval.

        Parameters
        ----------
        query_emb : (N, D_from) query embeddings from from_model
        db_emb : (M, D_to) database embeddings from to_model
        from_model : model name for queries
        to_model : model name for database
        top_k : number of results per query

        Returns
        -------
        {"indices": (N, top_k), "scores": (N, top_k)}
        """
        # Queries skip z-score; db side follows the normalize policy.
        q_rel = self.transform(from_model, query_emb, role="query")
        d_rel = self.transform(to_model, db_emb, role="db")

        # L2-normalize relative representations for cosine similarity
        # (epsilon guards against all-zero rows).
        q_norm = q_rel / (np.linalg.norm(q_rel, axis=1, keepdims=True) + 1e-10)
        d_norm = d_rel / (np.linalg.norm(d_rel, axis=1, keepdims=True) + 1e-10)

        scores = q_norm @ d_norm.T  # (N, M)
        # Clamp top_k so small databases don't make argsort slicing over-read.
        top_k = min(top_k, scores.shape[1])
        # NOTE(review): full-row argsort is O(M log M) per query; np.argpartition
        # would be cheaper for large M — presumably fine at current scales.
        indices = np.argsort(-scores, axis=1)[:, :top_k]
        sorted_scores = np.take_along_axis(scores, indices, axis=1)

        return {"indices": indices, "scores": sorted_scores}

    # ── Compatibility ──

    def estimate_compatibility(
        self,
        model_a: str,
        model_b: str,
    ) -> dict:
        """Estimate RAT compatibility between two registered models.

        Parameters
        ----------
        model_a, model_b : registered model names

        Returns
        -------
        dict with sim_mean_a, sim_mean_b, max_sim_mean, estimated_recall_at_1,
        z_score_recommendation, warnings

        Raises
        ------
        RuntimeError
            If either model was never registered.
        """
        if model_a not in self._models:
            raise RuntimeError(f"Model '{model_a}' not registered.")
        if model_b not in self._models:
            raise RuntimeError(f"Model '{model_b}' not registered.")

        sm_a = self._models[model_a]["sim_mean"]
        sm_b = self._models[model_b]["sim_mean"]
        # The worse (higher) of the two sim_means drives the recommendation.
        max_sm = max(sm_a, sm_b)
        rec = recommend_zscore(max_sm, self._normalize_threshold, self._normalize_harmful_threshold)

        warn_list: list[str] = []
        if max_sm > self._normalize_harmful_threshold:
            warn_list.append(
                f"High sim_mean ({max_sm:.3f}) detected. "
                "Z-score normalization may be harmful for this pair."
            )

        return {
            "sim_mean_a": sm_a,
            "sim_mean_b": sm_b,
            "max_sim_mean": max_sm,
            "estimated_recall_at_1": None,  # Phase 2: recall estimation not implemented yet
            "z_score_recommendation": rec,
            "warnings": warn_list,
        }

    # ── Persistence ──

    def save(self, path: str) -> None:
        """Save hub state to .npz file.

        Anchors are stored as ``anchor_<model_name>`` arrays; everything else
        (kernel config, thresholds, sim_means) travels as one JSON string
        under the ``config`` key.
        """
        save_dict: dict[str, object] = {}

        model_names = list(self._models.keys())
        for name in model_names:
            state = self._models[name]
            save_dict[f"anchor_{name}"] = state["anchors"]

        config = {
            "kernel": self._kernel_name,
            "kernel_params": self._kernel_params,
            "normalize": self._normalize,
            "normalize_threshold": self._normalize_threshold,
            "normalize_harmful_threshold": self._normalize_harmful_threshold,
            "model_names": model_names,
            "sim_means": {n: self._models[n]["sim_mean"] for n in model_names},
            "version": "0.1.0",
        }
        # Stored as a 0-d unicode array so the archive stays pickle-free.
        save_dict["config"] = np.array(json.dumps(config))

        # np.savez_compressed appends ".npz" itself when the suffix is missing.
        np.savez_compressed(path, **save_dict)

    @classmethod
    def load(cls, path: str) -> "RATHub":
        """Load hub state from .npz file.

        Mirrors :meth:`save`: accepts paths with or without the ".npz" suffix.
        """
        if not path.endswith(".npz"):
            path = path + ".npz"
        # allow_pickle=False: safe to load archives from untrusted sources.
        data = np.load(path, allow_pickle=False)
        config = json.loads(str(data["config"]))

        hub = cls(
            kernel=config["kernel"],
            kernel_params=config.get("kernel_params"),
            normalize=config["normalize"],
            normalize_threshold=config["normalize_threshold"],
            # Default covers archives written before this threshold existed.
            normalize_harmful_threshold=config.get("normalize_harmful_threshold", 0.65),
        )

        for name in config["model_names"]:
            anchors = data[f"anchor_{name}"]
            hub.set_anchors(name, anchors)
            # Override sim_mean from saved config (avoids recomputation drift)
            hub._models[name]["sim_mean"] = config["sim_means"][name]

        return hub
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def _check_l2_normalized(embeddings: np.ndarray) -> None:
|
|
287
|
+
"""Warn if embeddings don't appear L2-normalized."""
|
|
288
|
+
norms = np.linalg.norm(embeddings, axis=1)
|
|
289
|
+
if not np.allclose(norms, 1.0, atol=0.01):
|
|
290
|
+
warnings.warn(
|
|
291
|
+
"Input embeddings may not be L2-normalized. RAT expects normalized embeddings.",
|
|
292
|
+
stacklevel=3,
|
|
293
|
+
)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Kernel functions for computing relative representations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def poly_kernel(
    X: np.ndarray,
    A: np.ndarray,
    degree: int = 2,
    coef0: float = 1.0,
) -> np.ndarray:
    """Polynomial kernel: (X @ A.T + coef0) ** degree.

    Parameters
    ----------
    X : (N, D) input embeddings
    A : (K, D) anchor embeddings
    degree : polynomial degree
    coef0 : constant term

    Returns
    -------
    (N, K) kernel matrix
    """
    similarities = np.matmul(X, A.T)
    return np.power(similarities + coef0, degree)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def cosine_kernel(
    X: np.ndarray,
    A: np.ndarray,
) -> np.ndarray:
    """Cosine similarity kernel: X @ A.T (assumes L2-normalized inputs).

    Because both sides are assumed unit-norm, the inner product alone
    equals the cosine similarity — no renormalization is performed here.

    Parameters
    ----------
    X : (N, D) input embeddings, L2-normalized
    A : (K, D) anchor embeddings, L2-normalized

    Returns
    -------
    (N, K) similarity matrix
    """
    return np.matmul(X, A.T)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def rbf_kernel(
|
|
51
|
+
X: np.ndarray,
|
|
52
|
+
A: np.ndarray,
|
|
53
|
+
gamma: float | None = None,
|
|
54
|
+
) -> np.ndarray:
|
|
55
|
+
"""RBF (Gaussian) kernel: exp(-gamma * ||x - a||^2).
|
|
56
|
+
|
|
57
|
+
Parameters
|
|
58
|
+
----------
|
|
59
|
+
X : (N, D) input embeddings
|
|
60
|
+
A : (K, D) anchor embeddings
|
|
61
|
+
gamma : kernel width. None uses 1/D
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
(N, K) kernel matrix
|
|
66
|
+
"""
|
|
67
|
+
if gamma is None:
|
|
68
|
+
gamma = 1.0 / X.shape[1]
|
|
69
|
+
# ||x - a||^2 = ||x||^2 + ||a||^2 - 2 x·a
|
|
70
|
+
X_sq = np.sum(X ** 2, axis=1, keepdims=True) # (N, 1)
|
|
71
|
+
A_sq = np.sum(A ** 2, axis=1, keepdims=True) # (K, 1)
|
|
72
|
+
dist_sq = X_sq + A_sq.T - 2.0 * (X @ A.T) # (N, K)
|
|
73
|
+
dist_sq = np.maximum(dist_sq, 0.0) # numerical safety
|
|
74
|
+
return np.exp(-gamma * dist_sq)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_kernel(name: str, params: dict | None = None):
    """Factory: return a callable(X, A) -> np.ndarray for the named kernel.

    Parameters
    ----------
    name : "poly", "cosine", or "rbf"
    params : optional kwargs forwarded to the kernel function

    Returns
    -------
    callable(X, A) -> np.ndarray

    Raises
    ------
    ValueError
        If `name` is not a known kernel.
    """
    registry = {
        "poly": poly_kernel,
        "cosine": cosine_kernel,
        "rbf": rbf_kernel,
    }
    if name not in registry:
        raise ValueError(f"Unknown kernel '{name}'. Choose from: {list(registry.keys())}")
    # Bind the extra kwargs now so callers see a uniform (X, A) signature.
    return partial(registry[name], **(params or {}))
|