redisbench-admin 0.11.64__py3-none-any.whl → 0.11.65__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- redisbench_admin/run/ann/pkg/.dockerignore +2 -0
- redisbench_admin/run/ann/pkg/.git +1 -0
- redisbench_admin/run/ann/pkg/.github/workflows/benchmarks.yml +100 -0
- redisbench_admin/run/ann/pkg/.gitignore +21 -0
- redisbench_admin/run/ann/pkg/LICENSE +21 -0
- redisbench_admin/run/ann/pkg/README.md +157 -0
- redisbench_admin/run/ann/pkg/algos.yaml +1294 -0
- redisbench_admin/run/ann/pkg/algosP.yaml +67 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/__init__.py +2 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/__init__.py +0 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/annoy.py +26 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/balltree.py +22 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/base.py +36 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/bruteforce.py +110 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/ckdtree.py +17 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/datasketch.py +29 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/definitions.py +187 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/diskann.py +190 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/dolphinnpy.py +31 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/dummy_algo.py +25 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/elasticsearch.py +107 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/elastiknn.py +124 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/faiss.py +124 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/faiss_gpu.py +61 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/faiss_hnsw.py +39 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/flann.py +27 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/hnswlib.py +36 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/kdtree.py +22 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/kgraph.py +39 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/lshf.py +25 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/milvus.py +99 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/mrpt.py +41 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/n2.py +28 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/nearpy.py +48 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/nmslib.py +74 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/onng_ngt.py +100 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/opensearchknn.py +107 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/panng_ngt.py +79 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/pinecone.py +39 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/puffinn.py +45 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/pynndescent.py +115 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/qg_ngt.py +102 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/redisearch.py +90 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/rpforest.py +20 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/scann.py +34 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/sptag.py +28 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/subprocess.py +246 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/vald.py +149 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/vecsim-hnsw.py +43 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/algorithms/vespa.py +47 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/constants.py +1 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/data.py +48 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/datasets.py +620 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/distance.py +53 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/main.py +325 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/__init__.py +2 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/metrics.py +183 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/plot_variants.py +17 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/plotting/utils.py +165 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/results.py +71 -0
- redisbench_admin/run/ann/pkg/ann_benchmarks/runner.py +333 -0
- redisbench_admin/run/ann/pkg/create_dataset.py +12 -0
- redisbench_admin/run/ann/pkg/create_hybrid_dataset.py +147 -0
- redisbench_admin/run/ann/pkg/create_text_to_image_ds.py +117 -0
- redisbench_admin/run/ann/pkg/create_website.py +272 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile +11 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.annoy +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.datasketch +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.diskann +29 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.diskann_pq +31 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.dolphinn +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.elasticsearch +45 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.elastiknn +61 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.faiss +18 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.flann +10 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.hnswlib +10 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.kgraph +6 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.mih +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.milvus +27 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.mrpt +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.n2 +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.nearpy +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.ngt +13 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.nmslib +10 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.opensearchknn +43 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.puffinn +6 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.pynndescent +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.redisearch +18 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.rpforest +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.scann +5 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.scipy +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.sklearn +4 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.sptag +30 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.vald +8 -0
- redisbench_admin/run/ann/pkg/install/Dockerfile.vespa +17 -0
- redisbench_admin/run/ann/pkg/install.py +70 -0
- redisbench_admin/run/ann/pkg/logging.conf +34 -0
- redisbench_admin/run/ann/pkg/multirun.py +298 -0
- redisbench_admin/run/ann/pkg/plot.py +159 -0
- redisbench_admin/run/ann/pkg/protocol/bf-runner +10 -0
- redisbench_admin/run/ann/pkg/protocol/bf-runner.py +204 -0
- redisbench_admin/run/ann/pkg/protocol/ext-add-query-metric.md +51 -0
- redisbench_admin/run/ann/pkg/protocol/ext-batch-queries.md +77 -0
- redisbench_admin/run/ann/pkg/protocol/ext-prepared-queries.md +77 -0
- redisbench_admin/run/ann/pkg/protocol/ext-query-parameters.md +47 -0
- redisbench_admin/run/ann/pkg/protocol/specification.md +194 -0
- redisbench_admin/run/ann/pkg/requirements.txt +14 -0
- redisbench_admin/run/ann/pkg/requirements_py38.txt +11 -0
- redisbench_admin/run/ann/pkg/results/fashion-mnist-784-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/results/gist-960-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/results/glove-100-angular.png +0 -0
- redisbench_admin/run/ann/pkg/results/glove-25-angular.png +0 -0
- redisbench_admin/run/ann/pkg/results/lastfm-64-dot.png +0 -0
- redisbench_admin/run/ann/pkg/results/mnist-784-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/results/nytimes-256-angular.png +0 -0
- redisbench_admin/run/ann/pkg/results/sift-128-euclidean.png +0 -0
- redisbench_admin/run/ann/pkg/run.py +12 -0
- redisbench_admin/run/ann/pkg/run_algorithm.py +3 -0
- redisbench_admin/run/ann/pkg/templates/chartjs.template +102 -0
- redisbench_admin/run/ann/pkg/templates/detail_page.html +23 -0
- redisbench_admin/run/ann/pkg/templates/general.html +58 -0
- redisbench_admin/run/ann/pkg/templates/latex.template +30 -0
- redisbench_admin/run/ann/pkg/templates/summary.html +60 -0
- redisbench_admin/run/ann/pkg/test/__init__.py +0 -0
- redisbench_admin/run/ann/pkg/test/test-jaccard.py +19 -0
- redisbench_admin/run/ann/pkg/test/test-metrics.py +99 -0
- redisbench_admin/run_remote/run_remote.py +1 -1
- {redisbench_admin-0.11.64.dist-info → redisbench_admin-0.11.65.dist-info}/METADATA +2 -5
- redisbench_admin-0.11.65.dist-info/RECORD +243 -0
- {redisbench_admin-0.11.64.dist-info → redisbench_admin-0.11.65.dist-info}/WHEEL +1 -1
- redisbench_admin-0.11.64.dist-info/RECORD +0 -117
- {redisbench_admin-0.11.64.dist-info/licenses → redisbench_admin-0.11.65.dist-info}/LICENSE +0 -0
- {redisbench_admin-0.11.64.dist-info → redisbench_admin-0.11.65.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import os
|
|
3
|
+
import vamanapy as vp
|
|
4
|
+
import numpy as np
|
|
5
|
+
import struct
|
|
6
|
+
import time
|
|
7
|
+
from ann_benchmarks.algorithms.base import BaseANN
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Vamana(BaseANN):
    """ann-benchmarks wrapper around the vamanapy (Vamana/DiskANN) in-memory index."""

    def __init__(self, metric, param):
        # vamanapy only distinguishes L2 vs inner product; 'angular' is mapped
        # to 'cosine' here and served with Metric.INNER_PRODUCT below.
        self.metric = {'angular': 'cosine', 'euclidean': 'l2'}[metric]
        self.l_build = int(param["l_build"])              # build-time search list size (L)
        self.max_outdegree = int(param["max_outdegree"])  # graph degree bound (R)
        self.alpha = float(param["alpha"])                # pruning parameter
        print("Vamana: L_Build = " + str(self.l_build))
        print("Vamana: R = " + str(self.max_outdegree))
        print("Vamana: Alpha = " + str(self.alpha))
        self.params = vp.Parameters()
        self.params.set("L", self.l_build)
        self.params.set("R", self.max_outdegree)
        self.params.set("C", 750)
        self.params.set("alpha", self.alpha)
        self.params.set("saturate_graph", False)
        self.params.set("num_threads", 1)

    def fit(self, X):
        """Write X to vamanapy's on-disk format, then build (or reload) the index."""

        def bin_to_float(binary):
            # Reinterpret a 32-bit binary string as an IEEE-754 float so the
            # integer row/column counts can be embedded in the float32 file header.
            return struct.unpack('!f',struct.pack('!I', int(binary, 2)))[0]

        print("Vamana: Starting Fit...")
        index_dir = 'indices'

        if not os.path.exists(index_dir):
            os.makedirs(index_dir)

        data_path = os.path.join(index_dir, 'base.bin')
        self.name = 'Vamana-{}-{}-{}'.format(self.l_build,
                                             self.max_outdegree, self.alpha)
        save_path = os.path.join(index_dir, self.name)
        print('Vamana: Index Stored At: ' + save_path)
        # Header: (num_points, dim) bit-cast into two float32 values, followed
        # by the flattened vectors -- presumably the layout vamanapy expects
        # for its data file (TODO confirm against vamanapy docs).
        shape = [np.float32(bin_to_float('{:032b}'.format(X.shape[0]))),
                 np.float32(bin_to_float('{:032b}'.format(X.shape[1])))]
        X = X.flatten()
        X = np.insert(X, 0, shape)
        X.tofile(data_path)

        # Build only when no saved index exists; otherwise reuse the cached one.
        if not os.path.exists(save_path):
            print('Vamana: Creating Index')
            s = time.time()
            if self.metric == 'l2':
                index = vp.SinglePrecisionIndex(vp.Metric.FAST_L2, data_path)
            elif self.metric == 'cosine':
                index = vp.SinglePrecisionIndex(vp.Metric.INNER_PRODUCT, data_path)
            else:
                # Unreachable: __init__ restricts metric to 'l2'/'cosine'.
                print('Vamana: Unknown Metric Error!')
            index.build(self.params, [])
            t = time.time()
            print('Vamana: Index Build Time (sec) = ' + str(t - s))
            index.save(save_path)
        if os.path.exists(save_path):
            print('Vamana: Loading Index: ' + str(save_path))
            s = time.time()
            if self.metric == 'l2':
                self.index = vp.SinglePrecisionIndex(vp.Metric.FAST_L2, data_path)
            elif self.metric == 'cosine':
                self.index = vp.SinglePrecisionIndex(vp.Metric.INNER_PRODUCT, data_path)
            else:
                # Unreachable: __init__ restricts metric to 'l2'/'cosine'.
                print('Vamana: Unknown Metric Error!')
            self.index.load(file_name = save_path)
            print("Vamana: Index Loaded")
            self.index.optimize_graph()
            print("Vamana: Graph Optimization Completed")
            t = time.time()
            print('Vamana: Index Load Time (sec) = ' + str(t - s))
        else:
            print("Vamana: Unexpected Index Build Time Error")

        print('Vamana: End of Fit')

    def set_query_arguments(self, l_search):
        # Query-time search-list size (trade-off between recall and speed).
        print("Vamana: L_Search = " + str(l_search))
        self.l_search = l_search

    def query(self, v, n):
        """Return the ids of the n nearest neighbours of a single vector v."""
        return self.index.single_numpy_query(v, n, self.l_search)

    def batch_query(self, X, n):
        # Stores the flat result; reshaped per query in get_batch_results().
        self.num_queries = X.shape[0]
        self.result = self.index.batch_numpy_query(X, n, self.num_queries, self.l_search)

    def get_batch_results(self):
        # Flat id array -> one row of neighbours per query.
        return self.result.reshape((self.num_queries, self.result.shape[0] // self.num_queries))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class VamanaPQ(BaseANN):
    """ann-benchmarks wrapper for vamanapy's product-quantized (PQ) index.

    Same flow as the plain ``Vamana`` wrapper, plus a ``chunks`` parameter
    (number of PQ sub-quantizers) and the ``pq_build``/``pq_load`` calls.
    """

    def __init__(self, metric, param):
        # vamanapy only distinguishes L2 vs inner product; 'angular' is mapped
        # to 'cosine' here and served with Metric.INNER_PRODUCT below.
        self.metric = {'angular': 'cosine', 'euclidean': 'l2'}[metric]
        self.l_build = int(param["l_build"])              # build-time search list size (L)
        self.max_outdegree = int(param["max_outdegree"])  # graph degree bound (R)
        self.alpha = float(param["alpha"])                # pruning parameter
        self.chunks = int(param["chunks"])                # PQ sub-quantizer count
        print("Vamana PQ: L_Build = " + str(self.l_build))
        print("Vamana PQ: R = " + str(self.max_outdegree))
        print("Vamana PQ: Alpha = " + str(self.alpha))
        print("Vamana PQ: Chunks = " + str(self.chunks))
        self.params = vp.Parameters()
        self.params.set("L", self.l_build)
        self.params.set("R", self.max_outdegree)
        self.params.set("C", 750)
        self.params.set("alpha", self.alpha)
        self.params.set("saturate_graph", False)
        self.params.set("num_chunks", self.chunks)
        self.params.set("num_threads", 1)

    def fit(self, X):
        """Write X to vamanapy's on-disk format, then PQ-build (or reload) the index.

        Raises:
            ValueError: if ``chunks`` exceeds the data dimensionality (each PQ
                chunk must cover at least one coordinate).
        """

        def bin_to_float(binary):
            # Reinterpret a 32-bit binary string as an IEEE-754 float so the
            # integer row/column counts can be embedded in the float32 file header.
            return struct.unpack('!f',struct.pack('!I', int(binary, 2)))[0]

        print("Vamana PQ: Starting Fit...")
        index_dir = 'indices'

        if self.chunks > X.shape[1]:
            # Fix: the original raised a bare ValueError with no message,
            # leaving the operator to guess which parameter was wrong.
            raise ValueError(
                "Vamana PQ: chunks ({}) cannot exceed data dimension ({})".format(
                    self.chunks, X.shape[1]))

        if not os.path.exists(index_dir):
            os.makedirs(index_dir)

        data_path = os.path.join(index_dir, 'base.bin')
        pq_path = os.path.join(index_dir, 'pq_memory_index')
        self.name = 'VamanaPQ-{}-{}-{}'.format(self.l_build,
                                               self.max_outdegree, self.alpha)
        save_path = os.path.join(index_dir, self.name)
        print('Vamana PQ: Index Stored At: ' + save_path)
        # Header: (num_points, dim) bit-cast into two float32 values, followed
        # by the flattened vectors -- presumably the layout vamanapy expects
        # for its data file (TODO confirm against vamanapy docs).
        shape = [np.float32(bin_to_float('{:032b}'.format(X.shape[0]))),
                 np.float32(bin_to_float('{:032b}'.format(X.shape[1])))]
        X = X.flatten()
        X = np.insert(X, 0, shape)
        X.tofile(data_path)

        # Build only when no saved index exists; otherwise reuse the cached one.
        if not os.path.exists(save_path):
            print('Vamana PQ: Creating Index')
            s = time.time()
            if self.metric == 'l2':
                index = vp.SinglePrecisionIndex(vp.Metric.FAST_L2, data_path)
            elif self.metric == 'cosine':
                index = vp.SinglePrecisionIndex(vp.Metric.INNER_PRODUCT, data_path)
            else:
                # Unreachable: __init__ restricts metric to 'l2'/'cosine'.
                print('Vamana PQ: Unknown Metric Error!')
            index.pq_build(data_path, pq_path, self.params)
            t = time.time()
            print('Vamana PQ: Index Build Time (sec) = ' + str(t - s))
            index.save(save_path)
        if os.path.exists(save_path):
            print('Vamana PQ: Loading Index: ' + str(save_path))
            s = time.time()
            if self.metric == 'l2':
                self.index = vp.SinglePrecisionIndex(vp.Metric.FAST_L2, data_path)
            elif self.metric == 'cosine':
                self.index = vp.SinglePrecisionIndex(vp.Metric.INNER_PRODUCT, data_path)
            else:
                # Unreachable: __init__ restricts metric to 'l2'/'cosine'.
                print('Vamana PQ: Unknown Metric Error!')
            self.index.load(file_name = save_path)
            print("Vamana PQ: Index Loaded")
            self.index.pq_load(pq_prefix_path = pq_path)
            print("Vamana PQ: PQ Data Loaded")
            self.index.optimize_graph()
            print("Vamana PQ: Graph Optimization Completed")
            t = time.time()
            print('Vamana PQ: Index Load Time (sec) = ' + str(t - s))
        else:
            print("Vamana PQ: Unexpected Index Build Time Error")

        print('Vamana PQ: End of Fit')

    def set_query_arguments(self, l_search):
        # Query-time search-list size (trade-off between recall and speed).
        print("Vamana PQ: L_Search = " + str(l_search))
        self.l_search = l_search

    def query(self, v, n):
        """Return the ids of the n nearest neighbours of a single vector v."""
        return self.index.pq_single_numpy_query(v, n, self.l_search)

    def batch_query(self, X, n):
        # Stores the flat result; reshaped per query in get_batch_results().
        self.num_queries = X.shape[0]
        self.result = self.index.pq_batch_numpy_query(X, n, self.num_queries, self.l_search)

    def get_batch_results(self):
        # Flat id array -> one row of neighbours per query.
        return self.result.reshape((self.num_queries, self.result.shape[0] // self.num_queries))
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
import sys
|
|
3
|
+
sys.path.append("install/lib-dolphinnpy") # noqa
|
|
4
|
+
import numpy
|
|
5
|
+
import ctypes
|
|
6
|
+
from dolphinn import Dolphinn
|
|
7
|
+
from utils import findmean, isotropize
|
|
8
|
+
from ann_benchmarks.algorithms.base import BaseANN
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DolphinnPy(BaseANN):
    """ann-benchmarks wrapper for Dolphinn (hypercube-mapping-based search)."""

    def __init__(self, num_probes):
        self.name = 'Dolphinn(num_probes={} )'.format(num_probes)
        self.num_probes = num_probes  # probe count forwarded to Dolphinn.queries()
        self.m = 1                    # dataset mean; recomputed by fit()
        self._index = None            # Dolphinn index; built by fit()

    def fit(self, X):
        if X.dtype != numpy.float32:
            X = numpy.array(X, dtype=numpy.float32)
        d = X.shape[1]
        # Isotropize the data around its estimated mean before indexing;
        # the same mean (self.m) must be reused for queries.
        self.m = findmean(X, d, 10)
        X = isotropize(X, d, self.m)
        # Hypercube dimension grows with log2 of the dataset size.
        hypercube_dim = int(numpy.log2(len(X))) - 2
        self._index = Dolphinn(X, d, hypercube_dim)

    def query(self, v, n):
        q = numpy.array([v])
        # Queries must undergo the same isotropization as the indexed data.
        q = isotropize(q, len(v), self.m)
        res = self._index.queries(q, n, self.num_probes)
        return res[0]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
import numpy as np
|
|
3
|
+
from ann_benchmarks.algorithms.base import BaseANN
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DummyAlgoMt(BaseANN):
    """Baseline "algorithm" that answers every query with random indices.

    Useful for sanity-checking the benchmark harness (multi-threaded variant).
    """

    def __init__(self, metric):
        # The metric is irrelevant for random answers; only the name is kept.
        self.name = 'DummyAlgoMultiThread'

    def fit(self, X):
        # Exclusive upper bound for the random index draw.
        self.len = len(X) - 1

    def query(self, v, n):
        # Ignore the query vector entirely; draw n indices in [0, self.len).
        return np.random.randint(0, self.len, size=n)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DummyAlgoSt(BaseANN):
    """Baseline "algorithm" that answers every query with random indices.

    Useful for sanity-checking the benchmark harness (single-threaded variant).
    """

    def __init__(self, metric):
        # The metric is irrelevant for random answers; only the name is kept.
        self.name = 'DummyAlgoSingleThread'

    def fit(self, X):
        # Exclusive upper bound for the random index draw.
        self.len = len(X) - 1

    def query(self, v, n):
        # Ignore the query vector entirely; draw n indices in [0, self.len).
        return np.random.randint(0, self.len, size=n)
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ann-benchmarks interfaces for Elasticsearch.
|
|
3
|
+
Note that this requires X-Pack, which is not included in the OSS version of Elasticsearch.
|
|
4
|
+
"""
|
|
5
|
+
import logging
|
|
6
|
+
from time import sleep
|
|
7
|
+
from os import environ
|
|
8
|
+
from urllib.error import URLError
|
|
9
|
+
|
|
10
|
+
from elasticsearch import Elasticsearch, BadRequestError
|
|
11
|
+
from elasticsearch.helpers import bulk
|
|
12
|
+
from elastic_transport.client_utils import DEFAULT
|
|
13
|
+
|
|
14
|
+
from ann_benchmarks.algorithms.base import BaseANN
|
|
15
|
+
|
|
16
|
+
# Configure the elasticsearch logger.
|
|
17
|
+
# By default, it writes an INFO statement for every request.
|
|
18
|
+
logging.getLogger("elasticsearch").setLevel(logging.WARN)
|
|
19
|
+
|
|
20
|
+
# Uncomment these lines if you want to see timing for every HTTP request and its duration.
|
|
21
|
+
# logging.basicConfig(level=logging.INFO)
|
|
22
|
+
# logging.getLogger("elasticsearch").setLevel(logging.INFO)
|
|
23
|
+
|
|
24
|
+
def es_wait(es):
    """Block until the Elasticsearch cluster reports at least 'yellow' health.

    Polls once per second for up to 30 attempts using the given client;
    raises RuntimeError if the cluster never becomes ready.
    """
    print("Waiting for elasticsearch health endpoint...")
    remaining = 30
    while remaining > 0:
        health = None
        try:
            health = es.cluster.health(wait_for_status='yellow', timeout='1s')
        except URLError:
            pass  # server not reachable yet; retry after the sleep below
        if health is not None and not health['timed_out']:
            # A non-timed-out response means the status reached 'yellow'.
            print("Elasticsearch is ready")
            return
        sleep(1)
        remaining -= 1
    raise RuntimeError("Failed to connect to elasticsearch server")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ElasticsearchScriptScoreQuery(BaseANN):
    """
    KNN using the Elasticsearch dense_vector datatype and script score functions.
    - Dense vector field type: https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html
    - Dense vector queries: https://www.elastic.co/guide/en/elasticsearch/reference/master/query-dsl-script-score-query.html
    """

    def __init__(self, metric: str, dimension: int, conn_params, method_param):
        self.name = f"elasticsearch-script-score-query_metric={metric}_dimension={dimension}_params{method_param}"
        # Map ann-benchmarks metric names to dense_vector similarity names.
        self.metric = {"euclidean": 'l2_norm', "angular": 'cosine'}[metric]
        self.method_param = method_param
        self.dimension = dimension
        self.timeout = 60 * 60  # one hour; index builds can be slow
        # Connection parameters with localhost/default-credential fallbacks.
        h = conn_params['host'] if conn_params['host'] is not None else 'localhost'
        p = conn_params['port'] if conn_params['port'] is not None else '9200'
        u = conn_params['user'] if conn_params['user'] is not None else 'elastic'
        a = conn_params['auth'] if conn_params['auth'] is not None else ''
        self.index = "ann_benchmark"
        self.shards = conn_params['shards']
        # Try plain HTTP first; on any failure fall back to HTTPS with an
        # optional CA bundle from the ELASTIC_CA environment variable.
        try:
            self.es = Elasticsearch(f"http://{h}:{p}", request_timeout=self.timeout, basic_auth=(u, a), refresh_interval=-1)
            self.es.info()
        except Exception:
            self.es = Elasticsearch(f"https://{h}:{p}", request_timeout=self.timeout, basic_auth=(u, a), ca_certs=environ.get('ELASTIC_CA', DEFAULT))
        self.batch_res = []
        es_wait(self.es)

    def fit(self, X):
        """Create the index (if absent), bulk-load X, then refresh and force-merge."""
        mappings = dict(
            properties=dict(
                id=dict(type="keyword", store=True),
                vec=dict(
                    type="dense_vector",
                    dims=self.dimension,
                    similarity=self.metric,
                    index=True,
                    index_options=self.method_param
                )
            )
        )
        try:
            self.es.indices.create(index=self.index, mappings=mappings, settings=dict(number_of_shards=self.shards, number_of_replicas=0))
        except BadRequestError as e:
            # Re-raise anything except "index already exists" (safe to reuse).
            if 'resource_already_exists_exception' not in e.message: raise e

        def gen():
            # One bulk action per vector; doc id mirrors the row index.
            for i, vec in enumerate(X):
                yield { "_op_type": "index", "_index": self.index, "vec": vec.tolist(), 'id': str(i) }

        (_, errors) = bulk(self.es, gen(), chunk_size=500, max_retries=9)
        assert len(errors) == 0, errors

        self.es.indices.refresh(index=self.index)
        # Merging to one segment makes query latencies comparable across runs.
        self.es.indices.forcemerge(index=self.index, max_num_segments=1)

    def set_query_arguments(self, ef):
        # num_candidates for the kNN search (recall/speed trade-off).
        self.ef = ef

    def query(self, q, n):
        """Return the integer ids of the n approximate nearest neighbours of q."""
        knn = dict(field='vec', query_vector=q.tolist(), k=n, num_candidates=self.ef)
        # filter_path/stored_fields trim the response to just the doc-value ids.
        res = self.es.knn_search(index=self.index, knn=knn, source=False, docvalue_fields=['id'],
                                 stored_fields="_none_", filter_path=["hits.hits.fields.id"])
        return [int(h['fields']['id'][0]) for h in res['hits']['hits']]

    def batch_query(self, X, n):
        # No server-side batching; issue the queries sequentially.
        self.batch_res = [self.query(q, n) for q in X]

    def get_batch_results(self):
        return self.batch_res
|
|
107
|
+
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ann-benchmarks interfaces for elastiknn: https://github.com/alexklibisz/elastiknn
|
|
3
|
+
Uses the elastiknn python client
|
|
4
|
+
To install a local copy of the client, run `pip install --upgrade -e /path/to/elastiknn/client-python/`
|
|
5
|
+
To monitor the Elasticsearch JVM using Visualvm, add `ports={ "8097": 8097 }` to the `containers.run` call in runner.py.
|
|
6
|
+
"""
|
|
7
|
+
from sys import stderr
|
|
8
|
+
from urllib.error import URLError
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from elastiknn.api import Vec
|
|
12
|
+
from elastiknn.models import ElastiknnModel
|
|
13
|
+
from elastiknn.utils import dealias_metric
|
|
14
|
+
|
|
15
|
+
from ann_benchmarks.algorithms.base import BaseANN
|
|
16
|
+
|
|
17
|
+
from urllib.request import Request, urlopen
|
|
18
|
+
from time import sleep, perf_counter
|
|
19
|
+
|
|
20
|
+
import logging
|
|
21
|
+
|
|
22
|
+
# Mute the elasticsearch logger.
|
|
23
|
+
# By default, it writes an INFO statement for every request.
|
|
24
|
+
logging.getLogger("elasticsearch").setLevel(logging.WARN)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def es_wait():
    """Poll the local Elasticsearch health endpoint until it answers.

    Tries once per second, up to 30 times; raises RuntimeError on failure.
    """
    print("Waiting for elasticsearch health endpoint...")
    req = Request("http://localhost:9200/_cluster/health?wait_for_status=yellow&timeout=1s")
    for _ in range(30):
        try:
            if urlopen(req).getcode() == 200:
                print("Elasticsearch is ready")
                return
        except URLError:
            pass  # not up yet; retry after the sleep below
        sleep(1)
    raise RuntimeError("Failed to connect to local elasticsearch")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Exact(BaseANN):
    """Exact (brute-force) nearest-neighbour search via the elastiknn plugin."""

    def __init__(self, metric: str, dimension: int):
        self.name = f"eknn-exact-metric={metric}_dimension={dimension}"
        self.metric = metric
        self.dimension = dimension
        # dealias_metric translates ann-benchmarks metric names for elastiknn.
        self.model = ElastiknnModel("exact", dealias_metric(metric))
        self.batch_res = None
        es_wait()

    def _handle_sparse(self, X):
        # convert list of lists of indices to sparse vectors.
        return [Vec.SparseBool(x, self.dimension) for x in X]

    def fit(self, X):
        # Boolean metrics arrive as index lists and need the SparseBool wrap.
        if self.metric in {'jaccard', 'hamming'}:
            return self.model.fit(self._handle_sparse(X), shards=1)[0]
        else:
            return self.model.fit(X, shards=1)

    def query(self, q, n):
        """Return the n nearest neighbours of a single query vector q."""
        if self.metric in {'jaccard', 'hamming'}:
            return self.model.kneighbors(self._handle_sparse([q]), n)[0]
        else:
            return self.model.kneighbors(np.expand_dims(q, 0), n)[0]

    def batch_query(self, X, n):
        # Result stashed for get_batch_results(), per the BaseANN batch protocol.
        if self.metric in {'jaccard', 'hamming'}:
            self.batch_res = self.model.kneighbors(self._handle_sparse(X), n)
        else:
            self.batch_res = self.model.kneighbors(X, n)

    def get_batch_results(self):
        return self.batch_res
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class L2Lsh(BaseANN):
    """L2 locality-sensitive hashing via the elastiknn plugin.

    Vectors are scaled into [0, 1] by the training-set maximum before
    indexing, so every query must be scaled by the same factor.
    """

    def __init__(self, L: int, k: int, w: int):
        self.name_prefix = f"eknn-l2lsh-L={L}-k={k}-w={w}"
        self.name = None  # set based on query args.
        self.model = ElastiknnModel("lsh", "l2", mapping_params=dict(L=L, k=k, w=w))
        self.X_max = 1.0            # training-set max; recomputed in fit()
        self.query_params = dict()
        self.batch_res = None
        self.sum_query_dur = 0      # cumulative query time, for the QPS guard
        self.num_queries = 0
        es_wait()

    def fit(self, X):
        print(f"{self.name_prefix}: indexing {len(X)} vectors")

        # I found it's best to scale the vectors into [0, 1], i.e. divide by the max.
        self.X_max = X.max()
        return self.model.fit(X / self.X_max, shards=1)

    def set_query_arguments(self, candidates: int, probes: int):
        # This gets called when starting a new batch of queries.
        # Update the name and model's query parameters based on the given params.
        self.name = f"{self.name_prefix}_candidates={candidates}_probes={probes}"
        self.model.set_query_params(dict(candidates=candidates, probes=probes))
        # Reset the counters.
        self.num_queries = 0
        self.sum_query_dur = 0

    def query(self, q, n):
        # If QPS after 100 queries is < 10, this setting is bad and won't complete within the default timeout.
        if self.num_queries > 100 and self.num_queries / self.sum_query_dur < 10:
            print("Throughput after 100 queries is less than 10 q/s. Terminating to avoid wasteful computation.", flush=True)
            exit(0)
        else:
            t0 = perf_counter()
            # Queries must be scaled by the same factor as the indexed data.
            res = self.model.kneighbors(np.expand_dims(q, 0) / self.X_max, n)[0]
            dur = (perf_counter() - t0)
            self.sum_query_dur += dur
            self.num_queries += 1
            return res

    def batch_query(self, X, n):
        # Fix: the original forgot to divide by X_max here, so batch queries
        # were searched on a different scale than the indexed vectors
        # (query() does scale). Apply the same normalization.
        self.batch_res = self.model.kneighbors(X / self.X_max, n)

    def get_batch_results(self):
        return self.batch_res
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
import sys
|
|
3
|
+
sys.path.append("install/lib-faiss") # noqa
|
|
4
|
+
import numpy
|
|
5
|
+
import sklearn.preprocessing
|
|
6
|
+
import ctypes
|
|
7
|
+
import faiss
|
|
8
|
+
from ann_benchmarks.algorithms.base import BaseANN
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Faiss(BaseANN):
    """Shared query logic for the faiss-based wrappers.

    Subclasses build ``self.index`` in ``fit`` and set ``self._metric``;
    this base class implements single and batch querying against it.
    """

    def query(self, v, n):
        """Return the ids of the n nearest neighbours of a single vector v."""
        if self._metric == 'angular':
            # Unit-normalize so inner-product search matches angular distance.
            v /= numpy.linalg.norm(v)
        D, I = self.index.search(numpy.expand_dims(
            v, axis=0).astype(numpy.float32), n)
        return I[0]

    def batch_query(self, X, n):
        if self._metric == 'angular':
            # Fix: the original divided by numpy.linalg.norm(X) -- the
            # Frobenius norm of the whole matrix -- which does NOT unit-
            # normalize each query row (unlike query() above). Normalize
            # per row instead.
            X /= numpy.linalg.norm(X, axis=1, keepdims=True)
        self.res = self.index.search(X.astype(numpy.float32), n)

    def get_batch_results(self):
        """Convert the raw (distances, labels) pair into per-query id lists,
        dropping the -1 padding faiss emits when fewer than n hits exist."""
        D, L = self.res
        res = []
        for i in range(len(D)):
            r = []
            for l, d in zip(L[i], D[i]):
                if l != -1:
                    r.append(l)
            res.append(r)
        return res
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class FaissLSH(Faiss):
    """Locality-sensitive-hashing index backed by ``faiss.IndexLSH``."""

    def __init__(self, metric, n_bits):
        self._n_bits = n_bits   # hash length in bits
        self.index = None       # built in fit()
        self._metric = metric
        self.name = 'FaissLSH(n_bits={})'.format(self._n_bits)

    def fit(self, X):
        # faiss operates on float32; convert only when necessary.
        data = X if X.dtype == numpy.float32 else X.astype(numpy.float32)
        dim = data.shape[1]
        self.index = faiss.IndexLSH(dim, self._n_bits)
        self.index.train(data)
        self.index.add(data)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class FaissIVF(Faiss):
    """IVF-Flat index (``faiss.IndexIVFFlat``) with an L2 coarse quantizer."""

    def __init__(self, metric, n_list):
        self._n_list = n_list   # number of inverted lists (coarse centroids)
        self._metric = metric

    def fit(self, X):
        # Angular distance is served by L2 search on unit-normalized rows.
        if self._metric == 'angular':
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        if X.dtype != numpy.float32:
            X = X.astype(numpy.float32)

        dim = X.shape[1]
        # The quantizer must outlive the IVF index, so keep a reference.
        self.quantizer = faiss.IndexFlatL2(dim)
        ivf = faiss.IndexIVFFlat(
            self.quantizer, dim, self._n_list, faiss.METRIC_L2)
        ivf.train(X)
        ivf.add(X)
        self.index = ivf

    def set_query_arguments(self, n_probe):
        # Reset the global distance-computation counters for get_additional().
        faiss.cvar.indexIVF_stats.reset()
        self._n_probe = n_probe
        self.index.nprobe = self._n_probe

    def get_additional(self):
        # Distance computations = fine-level distances plus one coarse
        # comparison per centroid per query.
        return {"dist_comps": faiss.cvar.indexIVF_stats.ndis +  # noqa
                faiss.cvar.indexIVF_stats.nq * self._n_list}

    def __str__(self):
        return 'FaissIVF(n_list=%d, n_probe=%d)' % (self._n_list,
                                                    self._n_probe)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class FaissIVFPQfs(Faiss):
    """IVF index with 4-bit fast-scan product quantisation plus an
    optional exact re-ranking stage (``IndexRefineFlat``).
    """

    def __init__(self, metric, n_list):
        self._n_list = n_list
        self._metric = metric

    def fit(self, X):
        """Train the IVF-PQfs index on X and wrap it in a refine index."""
        if X.dtype != numpy.float32:
            X = X.astype(numpy.float32)
        elif self._metric == 'angular':
            # normalize_L2 works in place; copy so the caller's data
            # is not silently rescaled when no astype copy was made.
            X = X.copy()
        if self._metric == 'angular':
            faiss.normalize_L2(X)

        d = X.shape[1]
        faiss_metric = (faiss.METRIC_INNER_PRODUCT
                        if self._metric == 'angular' else faiss.METRIC_L2)
        factory_string = f"IVF{self._n_list},PQ{d//2}x4fs"
        index = faiss.index_factory(d, factory_string, faiss_metric)
        index.train(X)
        index.add(X)
        # Bug fix: IndexRefineFlat keeps only a raw pointer to X's buffer
        # (faiss.swig_ptr does not hold a Python reference), so the array
        # must stay alive for the index's lifetime or searches read freed
        # memory. Pin it on self.
        self._refine_data = X
        index_refine = faiss.IndexRefineFlat(index, faiss.swig_ptr(X))
        self.base_index = index
        self.refine_index = index_refine

    def set_query_arguments(self, n_probe, k_reorder):
        """Select nprobe and the re-rank factor.

        ``k_reorder == 0`` disables the refine stage: queries then go
        straight to the base fast-scan index.
        """
        faiss.cvar.indexIVF_stats.reset()
        self._n_probe = n_probe
        self._k_reorder = k_reorder
        self.base_index.nprobe = self._n_probe
        self.refine_index.k_factor = self._k_reorder
        if self._k_reorder == 0:
            self.index = self.base_index
        else:
            self.index = self.refine_index

    def get_additional(self):
        """Distance-computation count from FAISS's global IVF stats."""
        return {"dist_comps": faiss.cvar.indexIVF_stats.ndis +  # noqa
                faiss.cvar.indexIVF_stats.nq * self._n_list}

    def __str__(self):
        return 'FaissIVFPQfs(n_list=%d, n_probe=%d, k_reorder=%d)' % (
            self._n_list, self._n_probe, self._k_reorder)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
import sys
|
|
3
|
+
# Assumes local installation of FAISS
|
|
4
|
+
sys.path.append("faiss") # noqa
|
|
5
|
+
import numpy
|
|
6
|
+
import ctypes
|
|
7
|
+
import faiss
|
|
8
|
+
from ann_benchmarks.algorithms.base import BaseANN
|
|
9
|
+
|
|
10
|
+
# Implementation based on
|
|
11
|
+
# https://github.com/facebookresearch/faiss/blob/master/benchs/bench_gpu_sift1m.py # noqa
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FaissGPU(BaseANN):
    """GPU-resident IVF-flat index (``faiss.GpuIndexIVFFlat``)."""

    def __init__(self, n_bits, n_probes):
        self.name = 'FaissGPU(n_bits={}, n_probes={})'.format(
            n_bits, n_probes)
        self._n_bits = n_bits
        self._n_probes = n_probes
        self._res = faiss.StandardGpuResources()
        self._index = None

    def fit(self, X):
        """Train and populate the GPU index from the dataset X."""
        data = X.astype(numpy.float32)
        self._index = faiss.GpuIndexIVFFlat(self._res, len(data[0]),
                                            self._n_bits, faiss.METRIC_L2)
        self._index.train(data)
        self._index.add(data)
        self._index.setNumProbes(self._n_probes)

    def query(self, v, n):
        """Labels of the n nearest neighbours; distances are discarded."""
        return [label for label, _dist in self.query_with_distances(v, n)]

    def query_with_distances(self, v, n):
        """Return ``[(label, distance), ...]`` for the n nearest
        neighbours of v, skipping FAISS's -1 padding labels."""
        q = v.astype(numpy.float32).reshape(1, -1)
        distances, labels = self._index.search(q, n)
        return [(lbl, dist)
                for lbl, dist in zip(labels[0], distances[0])
                if lbl != -1]

    def batch_query(self, X, n):
        """Search n neighbours for every row of X; result kept on self."""
        self.res = self._index.search(X.astype(numpy.float32), n)

    def get_batch_results(self):
        """Convert the last batch_query result into a list of label
        lists, dropping the -1 padding entries."""
        distances, labels = self.res
        out = []
        for row in range(len(distances)):
            out.append([lbl for lbl in labels[row] if lbl != -1])
        return out
|