vector-engine 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vector_engine/__init__.py +16 -0
- vector_engine/array.py +153 -0
- vector_engine/backends/__init__.py +13 -0
- vector_engine/backends/base.py +28 -0
- vector_engine/backends/bruteforce.py +107 -0
- vector_engine/backends/faiss_backend.py +123 -0
- vector_engine/backends/registry.py +15 -0
- vector_engine/eval/__init__.py +17 -0
- vector_engine/eval/retrieval.py +173 -0
- vector_engine/index.py +190 -0
- vector_engine/io/__init__.py +3 -0
- vector_engine/io/manifest.py +50 -0
- vector_engine/metric.py +58 -0
- vector_engine/ml/__init__.py +4 -0
- vector_engine/ml/clustering.py +56 -0
- vector_engine/ml/knn.py +71 -0
- vector_engine/results.py +15 -0
- vector_engine/training/__init__.py +3 -0
- vector_engine/training/hard_negative.py +140 -0
- vector_engine-1.0.0.dist-info/METADATA +342 -0
- vector_engine-1.0.0.dist-info/RECORD +24 -0
- vector_engine-1.0.0.dist-info/WHEEL +5 -0
- vector_engine-1.0.0.dist-info/licenses/LICENSE +21 -0
- vector_engine-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Vector Engine public API."""
|
|
2
|
+
|
|
3
|
+
__version__ = "1.0.0"
|
|
4
|
+
|
|
5
|
+
from .array import VectorArray
|
|
6
|
+
from .index import VectorIndex
|
|
7
|
+
from .metric import Metric
|
|
8
|
+
from .results import SearchResult
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"__version__",
|
|
12
|
+
"Metric",
|
|
13
|
+
"SearchResult",
|
|
14
|
+
"VectorArray",
|
|
15
|
+
"VectorIndex",
|
|
16
|
+
]
|
vector_engine/array.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Mapping, Sequence
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _normalize_l2(x: np.ndarray) -> np.ndarray:
|
|
10
|
+
norms = np.linalg.norm(x, axis=1, keepdims=True)
|
|
11
|
+
norms = np.where(norms == 0, 1.0, norms)
|
|
12
|
+
return x / norms
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class VectorArray:
    """Canonical vector container with IDs and metadata.

    Wraps a contiguous float32 matrix of shape (n, d) together with one ID
    per row (int or str, unique) and optional per-row metadata dicts.
    Instances are frozen; subsetting returns a new VectorArray.
    """

    values: np.ndarray                                  # float32 matrix, shape (n, d)
    ids: np.ndarray                                     # per-row int/str identifiers
    metadata: list[dict[str, Any]] | None = None        # optional per-row metadata
    source_framework: str = "numpy"                     # provenance of the input data
    source_device: str = "cpu"                          # original device of the input data
    _id_to_row: dict[int | str, int] | None = None      # populated by __post_init__

    def __post_init__(self) -> None:
        """Coerce values to contiguous float32 and build the id -> row index.

        Raises:
            ValueError: on bad shape, length mismatches, or duplicate ids.
            TypeError: if an id is neither int nor str.
        """
        values = np.ascontiguousarray(self.values, dtype=np.float32)
        if values.ndim != 2:
            raise ValueError("vector_array_error: values must be a 2D array with shape (n, d)")
        if values.shape[0] == 0 or values.shape[1] == 0:
            raise ValueError("vector_array_error: values must have at least one row and one column")
        if len(self.ids) != values.shape[0]:
            raise ValueError("vector_array_error: ids length must match number of rows")
        if self.metadata is not None and len(self.metadata) != values.shape[0]:
            raise ValueError("vector_array_error: metadata length must match number of rows")
        id_to_row: dict[int | str, int] = {}
        for i, raw in enumerate(self.ids.tolist()):
            if not isinstance(raw, (int, str, np.integer)):
                raise TypeError("vector_array_error: ids must contain only int or str values")
            if isinstance(raw, np.integer):
                # Canonicalize numpy ints so lookups with plain Python ints succeed.
                raw = int(raw)
            if raw in id_to_row:
                raise ValueError(f"vector_array_error: duplicate id found: {raw}")
            id_to_row[raw] = i

        # Frozen dataclass: write the coerced array and index via object.__setattr__.
        object.__setattr__(self, "values", values)
        object.__setattr__(self, "_id_to_row", id_to_row)

    @classmethod
    def from_numpy(
        cls,
        x: np.ndarray,
        *,
        ids: Sequence[int | str] | None = None,
        metadata: Sequence[Mapping[str, Any]] | None = None,
        normalize: bool = False,
    ) -> "VectorArray":
        """Build a VectorArray from an (n, d) numpy array.

        ids default to 0..n-1; metadata mappings are shallow-copied into dicts.
        When ``normalize`` is True, rows are L2-normalized first.
        """
        arr = np.asarray(x, dtype=np.float32)
        if arr.ndim != 2:
            raise ValueError("vector_array_error: input numpy array must be shape (n, d)")
        arr = np.ascontiguousarray(arr)
        if normalize:
            arr = _normalize_l2(arr)
        if ids is None:
            ids_array = np.arange(arr.shape[0], dtype=np.int64)
        else:
            ids_list = list(ids)
            if len(ids_list) != arr.shape[0]:
                raise ValueError("vector_array_error: ids length must match number of rows")
            # object dtype preserves mixed/str ids without numpy coercion.
            ids_array = np.asarray(ids_list, dtype=object)
        md = [dict(item) for item in metadata] if metadata is not None else None
        return cls(values=arr, ids=ids_array, metadata=md, source_framework="numpy")

    @classmethod
    def from_torch(
        cls,
        x: "torch.Tensor",  # type: ignore[name-defined]
        *,
        ids: Sequence[int | str] | None = None,
        metadata: Sequence[Mapping[str, Any]] | None = None,
        to_numpy: bool = True,
        normalize: bool = False,
    ) -> "VectorArray":
        """Build a VectorArray from a torch tensor (v0.1: always converts via numpy)."""
        if not to_numpy:
            raise ValueError("vector_array_error: v0.1 requires to_numpy=True for torch inputs")
        import torch

        if not isinstance(x, torch.Tensor):
            raise TypeError("vector_array_error: x must be a torch.Tensor")
        device = str(x.device)
        arr = x.detach().to("cpu").numpy()
        out = cls.from_numpy(arr, ids=ids, metadata=metadata, normalize=normalize)
        # Rebuild to record the original framework/device provenance.
        return cls(
            values=out.values,
            ids=out.ids,
            metadata=out.metadata,
            source_framework="torch",
            source_device=device,
        )

    @classmethod
    def from_jax(
        cls,
        x: "jax.Array",  # type: ignore[name-defined]
        *,
        ids: Sequence[int | str] | None = None,
        metadata: Sequence[Mapping[str, Any]] | None = None,
        to_numpy: bool = True,
        normalize: bool = False,
    ) -> "VectorArray":
        """Build a VectorArray from a jax array (v0.1: always converts via numpy)."""
        if not to_numpy:
            raise ValueError("vector_array_error: v0.1 requires to_numpy=True for jax inputs")
        import jax.numpy as jnp

        arr = np.asarray(jnp.asarray(x), dtype=np.float32)
        out = cls.from_numpy(arr, ids=ids, metadata=metadata, normalize=normalize)
        return cls(
            values=out.values,
            ids=out.ids,
            metadata=out.metadata,
            source_framework="jax",
            # NOTE(review): original device is not recorded for jax inputs.
            source_device="device",
        )

    def to_numpy(self, *, dtype: np.dtype = np.float32, copy: bool = False) -> np.ndarray:
        """Return the stored matrix, converting dtype if needed; copy only on request."""
        arr = self.values.astype(dtype, copy=False)
        if copy:
            return np.array(arr, copy=True)
        return arr

    def subset(self, ids: Sequence[int | str]) -> "VectorArray":
        """Return a new VectorArray with only the requested ids, in the given order.

        Raises:
            KeyError: if any requested id is unknown.
        """
        if self._id_to_row is None:
            # Always populated by __post_init__; explicit raise instead of an
            # assert so the guard survives `python -O`.
            raise RuntimeError("vector_array_error: id index is not initialized")
        rows = []
        for id_ in ids:
            if id_ not in self._id_to_row:
                raise KeyError(f"vector_array_error: unknown id in subset(): {id_}")
            rows.append(self._id_to_row[id_])
        values = self.values[rows]
        out_ids = np.asarray(list(ids), dtype=object)
        out_meta = None
        if self.metadata is not None:
            out_meta = [self.metadata[i] for i in rows]
        return VectorArray(
            values=values,
            ids=out_ids,
            metadata=out_meta,
            source_framework=self.source_framework,
            source_device=self.source_device,
        )

    @property
    def shape(self) -> tuple[int, int]:
        """(n_rows, n_dims) of the stored matrix."""
        return self.values.shape
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Import the concrete backends and register them by name, so that
# `get_backend("bruteforce")` / `get_backend("faiss")` work as soon as
# this package is imported.
from .bruteforce import BruteForceBackend
from .faiss_backend import FaissBackend
from .registry import get_backend, register_backend

register_backend("bruteforce", BruteForceBackend)
register_backend("faiss", FaissBackend)

__all__ = [
    "BruteForceBackend",
    "FaissBackend",
    "get_backend",
    "register_backend",
]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from vector_engine.metric import Metric
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseBackend(Protocol):
    """Structural (duck-typed) interface every search backend must satisfy."""

    # Short identifier used by the backend registry (e.g. "bruteforce").
    name: str
    # Feature flags, e.g. "supports_delete" / "supports_custom_metric" /
    # "supports_persistence".
    capabilities: dict[str, bool]

    def build(self, xb: np.ndarray, metric: Metric, config: dict) -> None:
        """Build the index over base vectors ``xb`` for ``metric`` using ``config``."""
        ...

    def add(self, xb: np.ndarray) -> np.ndarray:
        """Append vectors to a built index and return their internal row ids."""
        ...

    def search(self, xq: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
        """Return (scores, internal_ids) for the top-k matches of each query row."""
        ...

    def save(self, path: str) -> None:
        """Persist the backend's state under directory ``path``."""
        ...

    @classmethod
    def load(cls, path: str) -> "BaseBackend":
        """Restore a backend previously written by ``save`` from ``path``."""
        ...
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from vector_engine.metric import Metric
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _normalize_l2(x: np.ndarray) -> np.ndarray:
|
|
13
|
+
norms = np.linalg.norm(x, axis=1, keepdims=True)
|
|
14
|
+
norms = np.where(norms == 0, 1.0, norms)
|
|
15
|
+
return x / norms
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class BruteForceBackend:
    """Exact nearest-neighbor search by scoring every database vector.

    Supports the built-in "ip", "cosine" and "l2" metrics plus arbitrary
    pairwise metric callables. Cosine is implemented as inner product over
    L2-normalized vectors: database rows are normalized at build/add time,
    queries at search time.
    """

    name: str = "bruteforce"
    capabilities: dict[str, bool] = field(
        default_factory=lambda: {
            "supports_delete": False,
            "supports_custom_metric": True,
            "supports_persistence": True,
        }
    )
    xb: np.ndarray | None = None      # database matrix, float32, shape (n, d)
    metric: Metric | None = None      # metric recorded by build()/load()

    def build(self, xb: np.ndarray, metric: Metric, config: dict) -> None:
        """Store the database vectors; ``config`` is unused by this backend."""
        arr = np.ascontiguousarray(xb, dtype=np.float32)
        self.metric = metric
        if metric.name == "cosine":
            arr = _normalize_l2(arr)
        self.xb = arr

    def add(self, xb: np.ndarray) -> np.ndarray:
        """Append vectors and return their assigned internal row ids."""
        if self.xb is None or self.metric is None:
            raise RuntimeError("backend is not built")
        arr = np.ascontiguousarray(xb, dtype=np.float32)
        if self.metric.name == "cosine":
            arr = _normalize_l2(arr)
        start = self.xb.shape[0]
        self.xb = np.vstack([self.xb, arr])
        return np.arange(start, start + arr.shape[0], dtype=np.int64)

    def search(self, xq: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
        """Return (scores, internal_ids), each shaped (n_queries, min(k, n_db)).

        Raises:
            RuntimeError: if build() has not been called.
            ValueError: if k < 1 or the metric is unsupported.
        """
        if self.xb is None or self.metric is None:
            raise RuntimeError("backend is not built")
        if k < 1:
            # argpartition with kth=k-1 would silently misbehave for k <= 0
            # (kth=-1 selects from the wrong end) — fail loudly instead.
            raise ValueError("k must be >= 1")
        queries = np.ascontiguousarray(xq, dtype=np.float32)
        if self.metric.name == "cosine":
            queries = _normalize_l2(queries)

        if self.metric.fn is not None:
            # Metric fn returns pairwise matrix, shape (n_queries, n_db).
            scores = self.metric.fn(queries, self.xb)
        elif self.metric.name in ("ip", "cosine"):
            scores = queries @ self.xb.T
        elif self.metric.name == "l2":
            # Squared Euclidean distance via explicit broadcasting.
            diff = queries[:, None, :] - self.xb[None, :, :]
            scores = np.sum(diff * diff, axis=2)
        else:
            raise ValueError(f"unsupported metric for brute force: {self.metric.name}")

        k = min(k, self.xb.shape[0])
        # Partial top-k selection, then an exact sort of just the k candidates.
        if self.metric.higher_is_better:
            idx = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
            row = np.arange(scores.shape[0])[:, None]
            val = scores[row, idx]
            order = np.argsort(-val, axis=1)
        else:
            idx = np.argpartition(scores, kth=k - 1, axis=1)[:, :k]
            row = np.arange(scores.shape[0])[:, None]
            val = scores[row, idx]
            order = np.argsort(val, axis=1)
        sorted_idx = np.take_along_axis(idx, order, axis=1)
        sorted_scores = np.take_along_axis(scores, sorted_idx, axis=1)
        return sorted_scores, sorted_idx.astype(np.int64)

    def save(self, path: str) -> None:
        """Persist the vectors and metric description under directory ``path``."""
        if self.xb is None or self.metric is None:
            raise RuntimeError("backend is not built")
        os.makedirs(path, exist_ok=True)
        np.save(os.path.join(path, "vectors.npy"), self.xb)
        with open(os.path.join(path, "backend_meta.json"), "w", encoding="utf-8") as f:
            json.dump(
                {
                    "name": self.name,
                    "metric_name": self.metric.name,
                    "higher_is_better": self.metric.higher_is_better,
                    # Custom metric callables cannot be serialized; flag them so
                    # load() can refuse cleanly.
                    "has_custom_metric": self.metric.fn is not None,
                },
                f,
            )

    @classmethod
    def load(cls, path: str) -> "BruteForceBackend":
        """Restore a backend previously written by save().

        Raises:
            ValueError: if the saved index used a custom metric callable.
        """
        with open(os.path.join(path, "backend_meta.json"), "r", encoding="utf-8") as f:
            meta = json.load(f)
        if meta["has_custom_metric"]:
            raise ValueError("cannot load custom brute force metric without explicit code hook")
        xb = np.load(os.path.join(path, "vectors.npy"))
        backend = cls()
        backend.metric = Metric(name=meta["metric_name"], higher_is_better=meta["higher_is_better"])
        backend.xb = xb
        return backend
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from vector_engine.metric import Metric
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _normalize_l2(x: np.ndarray) -> np.ndarray:
|
|
13
|
+
norms = np.linalg.norm(x, axis=1, keepdims=True)
|
|
14
|
+
norms = np.where(norms == 0, 1.0, norms)
|
|
15
|
+
return x / norms
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class FaissBackend:
    """Search backend delegating to the optional ``faiss`` library.

    Supports the "l2" and "ip" metrics natively; "cosine" is mapped to inner
    product over L2-normalized vectors. Custom metric callables are rejected.
    """

    name: str = "faiss"
    capabilities: dict[str, bool] = field(
        default_factory=lambda: {
            "supports_delete": False,
            "supports_custom_metric": False,
            "supports_persistence": True,
        }
    )
    # Underlying faiss index object (typed as object so faiss stays optional).
    _index: object | None = None
    # Metric recorded at build/load time; drives query normalization.
    _metric: Metric | None = None
    # Number of vectors added so far; used to assign internal ids in add().
    _count: int = 0
    # Build configuration ("index_factory" string, "nprobe", ...).
    _config: dict | None = None

    def _faiss(self) -> object:
        """Import and return the faiss module, with an install hint on failure."""
        try:
            import faiss  # type: ignore
        except Exception as exc:  # pragma: no cover - import path depends on env
            raise ImportError(
                "Faiss backend requested but faiss is not installed. "
                "Install with `pip install faiss-cpu`."
            ) from exc
        return faiss

    def build(self, xb: np.ndarray, metric: Metric, config: dict) -> None:
        """Create, train (if required) and populate a faiss index over ``xb``.

        Raises:
            ValueError: for custom metric callables or unsupported metric names.
        """
        if metric.fn is not None:
            raise ValueError("faiss backend does not support custom metric functions")
        faiss = self._faiss()
        arr = np.ascontiguousarray(xb, dtype=np.float32)
        if metric.name == "cosine":
            # Cosine similarity == inner product on unit-norm vectors.
            arr = _normalize_l2(arr)
            metric_for_faiss = "ip"
        else:
            metric_for_faiss = metric.name

        index_factory = config.get("index_factory", "Flat")
        if metric_for_faiss == "l2":
            index = faiss.index_factory(arr.shape[1], index_factory, faiss.METRIC_L2)
        elif metric_for_faiss == "ip":
            index = faiss.index_factory(arr.shape[1], index_factory, faiss.METRIC_INNER_PRODUCT)
        else:
            raise ValueError(f"unsupported metric for faiss: {metric.name}")

        # Some factory strings (e.g. IVF variants) need training before add.
        if hasattr(index, "train") and not index.is_trained:
            index.train(arr)
        index.add(arr)

        # Apply search-time tuning only where the index type supports it.
        nprobe = config.get("nprobe")
        if nprobe is not None and hasattr(index, "nprobe"):
            index.nprobe = int(nprobe)

        self._index = index
        self._metric = metric
        self._count = arr.shape[0]
        self._config = dict(config)

    def add(self, xb: np.ndarray) -> np.ndarray:
        """Append vectors to the built index; returns their internal row ids."""
        if self._index is None or self._metric is None:
            raise RuntimeError("backend is not built")
        arr = np.ascontiguousarray(xb, dtype=np.float32)
        if self._metric.name == "cosine":
            arr = _normalize_l2(arr)
        start = self._count
        self._index.add(arr)
        self._count += arr.shape[0]
        return np.arange(start, self._count, dtype=np.int64)

    def search(self, xq: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
        """Return (scores, internal_ids) from faiss for the top-k of each query."""
        if self._index is None or self._metric is None:
            raise RuntimeError("backend is not built")
        queries = np.ascontiguousarray(xq, dtype=np.float32)
        if self._metric.name == "cosine":
            queries = _normalize_l2(queries)
        scores, internal_ids = self._index.search(queries, k)
        return scores, internal_ids.astype(np.int64)

    def save(self, path: str) -> None:
        """Persist the faiss index and metadata under directory ``path``."""
        if self._index is None or self._metric is None:
            raise RuntimeError("backend is not built")
        faiss = self._faiss()
        os.makedirs(path, exist_ok=True)
        faiss.write_index(self._index, os.path.join(path, "faiss.index"))
        with open(os.path.join(path, "backend_meta.json"), "w", encoding="utf-8") as f:
            json.dump(
                {
                    "name": self.name,
                    "metric_name": self._metric.name,
                    "higher_is_better": self._metric.higher_is_better,
                    "config": self._config or {},
                    "count": self._count,
                },
                f,
            )

    @classmethod
    def load(cls, path: str) -> "FaissBackend":
        """Restore a backend previously written by save() from ``path``."""
        backend = cls()
        faiss = backend._faiss()
        with open(os.path.join(path, "backend_meta.json"), "r", encoding="utf-8") as f:
            meta = json.load(f)
        backend._index = faiss.read_index(os.path.join(path, "faiss.index"))
        backend._metric = Metric(name=meta["metric_name"], higher_is_better=meta["higher_is_better"])
        backend._config = meta.get("config", {})
        backend._count = int(meta.get("count", 0))
        return backend
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
# Global name -> backend-class registry, populated by the package __init__
# and extensible by callers via register_backend().
BACKENDS: dict[str, type[Any]] = {}


def register_backend(name: str, backend_cls: type[Any]) -> None:
    """Register *backend_cls* under *name*, replacing any previous entry."""
    BACKENDS[name] = backend_cls


def get_backend(name: str) -> type[Any]:
    """Look up a registered backend class by name.

    Raises:
        ValueError: if *name* was never registered; the message lists the
            currently-known backend names to make typos easy to spot.
    """
    try:
        return BACKENDS[name]
    except KeyError:
        available = ", ".join(sorted(BACKENDS)) or "<none>"
        raise ValueError(f"unknown backend: {name} (available: {available})") from None
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .retrieval import (
|
|
2
|
+
batch_metrics_summary,
|
|
3
|
+
ndcg_at_k,
|
|
4
|
+
precision_at_k,
|
|
5
|
+
recall_at_k,
|
|
6
|
+
retrieval_report,
|
|
7
|
+
retrieval_report_detailed,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"batch_metrics_summary",
|
|
12
|
+
"ndcg_at_k",
|
|
13
|
+
"precision_at_k",
|
|
14
|
+
"recall_at_k",
|
|
15
|
+
"retrieval_report",
|
|
16
|
+
"retrieval_report_detailed",
|
|
17
|
+
]
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Iterable, Sequence
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _normalize_ground_truth(
|
|
10
|
+
ground_truth_ids: np.ndarray | Iterable[Iterable[object]],
|
|
11
|
+
n_queries: int,
|
|
12
|
+
) -> list[set[object]]:
|
|
13
|
+
if isinstance(ground_truth_ids, np.ndarray):
|
|
14
|
+
if ground_truth_ids.ndim == 2:
|
|
15
|
+
rows = [ground_truth_ids[i].tolist() for i in range(ground_truth_ids.shape[0])]
|
|
16
|
+
elif ground_truth_ids.ndim == 1:
|
|
17
|
+
rows = ground_truth_ids.tolist()
|
|
18
|
+
else:
|
|
19
|
+
raise ValueError("eval_error: ground_truth_ids must be 1D or 2D")
|
|
20
|
+
else:
|
|
21
|
+
rows = list(ground_truth_ids)
|
|
22
|
+
|
|
23
|
+
if len(rows) != n_queries:
|
|
24
|
+
raise ValueError("eval_error: retrieved_ids and ground_truth_ids must have same number of queries")
|
|
25
|
+
|
|
26
|
+
normalized: list[set[object]] = []
|
|
27
|
+
for i, row in enumerate(rows):
|
|
28
|
+
if isinstance(row, (str, bytes)) or not isinstance(row, Iterable):
|
|
29
|
+
raise TypeError(f"eval_error: ground_truth_ids row {i} must be an iterable of IDs")
|
|
30
|
+
normalized.append(set(row))
|
|
31
|
+
return normalized
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _ensure_2d(name: str, value: np.ndarray) -> np.ndarray:
|
|
35
|
+
arr = np.asarray(value, dtype=object)
|
|
36
|
+
if arr.ndim != 2:
|
|
37
|
+
raise ValueError(f"eval_error: {name} must be a 2D array")
|
|
38
|
+
if arr.shape[0] == 0:
|
|
39
|
+
raise ValueError(f"eval_error: {name} cannot be empty")
|
|
40
|
+
return arr
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _validate_inputs(
    retrieved_ids: np.ndarray,
    ground_truth_ids: np.ndarray | Iterable[Iterable[object]],
    k: int,
) -> tuple[np.ndarray, list[set[object]]]:
    """Check *k* and shapes, returning the validated (retrieved, ground-truth) pair."""
    if not isinstance(k, int):
        raise TypeError("eval_error: k must be an int")
    if k <= 0:
        raise ValueError("eval_error: k must be > 0")
    checked = _ensure_2d("retrieved_ids", retrieved_ids)
    if k > checked.shape[1]:
        raise ValueError("eval_error: k cannot exceed retrieved_ids width")
    gt = _normalize_ground_truth(ground_truth_ids, n_queries=checked.shape[0])
    return checked, gt
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _per_query_precision(retrieved_row: Sequence[object], gt_row: set[object], k: int) -> float:
|
|
59
|
+
hit = sum(1 for item in list(retrieved_row)[:k] if item in gt_row)
|
|
60
|
+
return hit / float(k)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _per_query_recall(retrieved_row: Sequence[object], gt_row: set[object], k: int) -> float:
|
|
64
|
+
if len(gt_row) == 0:
|
|
65
|
+
return 0.0
|
|
66
|
+
hit = sum(1 for item in list(retrieved_row)[:k] if item in gt_row)
|
|
67
|
+
return hit / float(len(gt_row))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _per_query_ndcg(retrieved_row: Sequence[object], gt_row: set[object], k: int) -> float:
|
|
71
|
+
dcg = 0.0
|
|
72
|
+
for rank, item in enumerate(list(retrieved_row)[:k], start=1):
|
|
73
|
+
rel = 1.0 if item in gt_row else 0.0
|
|
74
|
+
dcg += rel / math.log2(rank + 1)
|
|
75
|
+
ideal_hits = min(k, len(gt_row))
|
|
76
|
+
idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
|
|
77
|
+
return 0.0 if idcg == 0 else dcg / idcg
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def precision_at_k(
    retrieved_ids: np.ndarray,
    ground_truth_ids: np.ndarray | Iterable[Iterable[object]],
    k: int,
) -> float:
    """Mean precision@k over all queries."""
    retrieved, gt = _validate_inputs(retrieved_ids, ground_truth_ids, k)
    per_query = np.asarray(
        [_per_query_precision(retrieved[q].tolist(), gt[q], k) for q in range(retrieved.shape[0])],
        dtype=np.float64,
    )
    return float(per_query.mean())
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def recall_at_k(
    retrieved_ids: np.ndarray,
    ground_truth_ids: np.ndarray | Iterable[Iterable[object]],
    k: int,
) -> float:
    """Mean recall@k over all queries."""
    retrieved, gt = _validate_inputs(retrieved_ids, ground_truth_ids, k)
    per_query = np.asarray(
        [_per_query_recall(retrieved[q].tolist(), gt[q], k) for q in range(retrieved.shape[0])],
        dtype=np.float64,
    )
    return float(per_query.mean())
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def ndcg_at_k(
    retrieved_ids: np.ndarray,
    ground_truth_ids: np.ndarray | Iterable[Iterable[object]],
    k: int,
) -> float:
    """Mean NDCG@k over all queries."""
    retrieved, gt = _validate_inputs(retrieved_ids, ground_truth_ids, k)
    per_query = np.asarray(
        [_per_query_ndcg(retrieved[q].tolist(), gt[q], k) for q in range(retrieved.shape[0])],
        dtype=np.float64,
    )
    return float(per_query.mean())
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def retrieval_report(
    retrieved_ids: np.ndarray,
    ground_truth_ids: np.ndarray | Iterable[Iterable[object]],
    ks: Sequence[int] = (1, 5, 10),
) -> dict[str, float]:
    """Compute precision/recall/NDCG at every cutoff in *ks*."""
    if not ks:
        raise ValueError("eval_error: ks cannot be empty")
    metrics: dict[str, float] = {}
    for cutoff in ks:
        if cutoff <= 0:
            raise ValueError("eval_error: each k must be > 0")
        metrics[f"precision@{cutoff}"] = precision_at_k(retrieved_ids, ground_truth_ids, cutoff)
        metrics[f"recall@{cutoff}"] = recall_at_k(retrieved_ids, ground_truth_ids, cutoff)
        metrics[f"ndcg@{cutoff}"] = ndcg_at_k(retrieved_ids, ground_truth_ids, cutoff)
    return metrics
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def retrieval_report_detailed(
    retrieved_ids: np.ndarray,
    ground_truth_ids: np.ndarray | Iterable[Iterable[object]],
    ks: Sequence[int] = (1, 5, 10),
    *,
    include_per_query: bool = True,
) -> dict[str, object]:
    """Compute aggregate retrieval metrics plus an optional per-query breakdown."""
    checked = _ensure_2d("retrieved_ids", retrieved_ids)
    payload: dict[str, object] = {
        "summary": retrieval_report(retrieved_ids, ground_truth_ids, ks=ks),
    }
    if include_per_query:
        gt = _normalize_ground_truth(ground_truth_ids, n_queries=checked.shape[0])
        rows: list[dict[str, float]] = []
        for q in range(checked.shape[0]):
            retrieved_row = checked[q].tolist()
            entry: dict[str, float] = {}
            for cutoff in ks:
                entry[f"precision@{cutoff}"] = _per_query_precision(retrieved_row, gt[q], cutoff)
                entry[f"recall@{cutoff}"] = _per_query_recall(retrieved_row, gt[q], cutoff)
                entry[f"ndcg@{cutoff}"] = _per_query_ndcg(retrieved_row, gt[q], cutoff)
            rows.append(entry)
        payload["per_query"] = rows
    return payload
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def batch_metrics_summary(
    reports: Sequence[dict[str, float]],
    *,
    include_std: bool = False,
) -> dict[str, float]:
    """Macro-average a list of retrieval reports, optionally with std deviations."""
    if not reports:
        raise ValueError("eval_error: reports cannot be empty")
    expected_keys = sorted(reports[0].keys())
    # All reports must describe the same metrics or the averages are meaningless.
    for report in reports:
        if sorted(report.keys()) != expected_keys:
            raise ValueError("eval_error: all reports must contain identical metric keys")
    out: dict[str, float] = {}
    for key in expected_keys:
        column = np.asarray([float(r[key]) for r in reports], dtype=np.float64)
        out[key] = float(column.mean())
        if include_std:
            out[f"{key}_std"] = float(column.std())
    out["num_reports"] = float(len(reports))
    return out
|