faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -5,15 +5,25 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
#include <faiss/impl/HNSW.h>
|
11
9
|
|
10
|
+
#include <cstddef>
|
12
11
|
#include <string>
|
13
12
|
|
14
13
|
#include <faiss/impl/AuxIndexStructures.h>
|
15
14
|
#include <faiss/impl/DistanceComputer.h>
|
16
15
|
#include <faiss/impl/IDSelector.h>
|
16
|
+
#include <faiss/impl/ResultHandler.h>
|
17
|
+
#include <faiss/utils/prefetch.h>
|
18
|
+
|
19
|
+
#include <faiss/impl/platform_macros.h>
|
20
|
+
|
21
|
+
#ifdef __AVX2__
|
22
|
+
#include <immintrin.h>
|
23
|
+
|
24
|
+
#include <limits>
|
25
|
+
#include <type_traits>
|
26
|
+
#endif
|
17
27
|
|
18
28
|
namespace faiss {
|
19
29
|
|
@@ -101,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
|
|
101
111
|
level,
|
102
112
|
nb_neighbors(level));
|
103
113
|
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
|
104
|
-
#pragma omp parallel for reduction(
|
105
|
-
|
114
|
+
#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
|
115
|
+
reduction(+ : tot_reciprocal) reduction(+ : n_node)
|
106
116
|
for (int i = 0; i < levels.size(); i++) {
|
107
117
|
if (levels[i] > level) {
|
108
118
|
n_node++;
|
@@ -206,13 +216,13 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
|
|
206
216
|
if (pt_level > max_level)
|
207
217
|
max_level = pt_level;
|
208
218
|
offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
|
209
|
-
neighbors.resize(offsets.back(), -1);
|
210
219
|
}
|
220
|
+
neighbors.resize(offsets.back(), -1);
|
211
221
|
|
212
222
|
return max_level;
|
213
223
|
}
|
214
224
|
|
215
|
-
/** Enumerate vertices from
|
225
|
+
/** Enumerate vertices from nearest to farthest from query, keep a
|
216
226
|
* neighbor only if there is no previous neighbor that is closer to
|
217
227
|
* that vertex than the query.
|
218
228
|
*/
|
@@ -220,7 +230,14 @@ void HNSW::shrink_neighbor_list(
|
|
220
230
|
DistanceComputer& qdis,
|
221
231
|
std::priority_queue<NodeDistFarther>& input,
|
222
232
|
std::vector<NodeDistFarther>& output,
|
223
|
-
int max_size
|
233
|
+
int max_size,
|
234
|
+
bool keep_max_size_level0) {
|
235
|
+
// This prevents number of neighbors at
|
236
|
+
// level 0 from being shrunk to less than 2 * M.
|
237
|
+
// This is essential in making sure
|
238
|
+
// `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
|
239
|
+
std::vector<NodeDistFarther> outsiders;
|
240
|
+
|
224
241
|
while (input.size() > 0) {
|
225
242
|
NodeDistFarther v1 = input.top();
|
226
243
|
input.pop();
|
@@ -241,8 +258,15 @@ void HNSW::shrink_neighbor_list(
|
|
241
258
|
if (output.size() >= max_size) {
|
242
259
|
return;
|
243
260
|
}
|
261
|
+
} else if (keep_max_size_level0) {
|
262
|
+
outsiders.push_back(v1);
|
244
263
|
}
|
245
264
|
}
|
265
|
+
size_t idx = 0;
|
266
|
+
while (keep_max_size_level0 && (output.size() < max_size) &&
|
267
|
+
(idx < outsiders.size())) {
|
268
|
+
output.push_back(outsiders[idx++]);
|
269
|
+
}
|
246
270
|
}
|
247
271
|
|
248
272
|
namespace {
|
@@ -259,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
|
|
259
283
|
void shrink_neighbor_list(
|
260
284
|
DistanceComputer& qdis,
|
261
285
|
std::priority_queue<NodeDistCloser>& resultSet1,
|
262
|
-
int max_size
|
286
|
+
int max_size,
|
287
|
+
bool keep_max_size_level0 = false) {
|
263
288
|
if (resultSet1.size() < max_size) {
|
264
289
|
return;
|
265
290
|
}
|
@@ -271,7 +296,8 @@ void shrink_neighbor_list(
|
|
271
296
|
resultSet1.pop();
|
272
297
|
}
|
273
298
|
|
274
|
-
HNSW::shrink_neighbor_list(
|
299
|
+
HNSW::shrink_neighbor_list(
|
300
|
+
qdis, resultSet, returnlist, max_size, keep_max_size_level0);
|
275
301
|
|
276
302
|
for (NodeDistFarther curen2 : returnlist) {
|
277
303
|
resultSet1.emplace(curen2.d, curen2.id);
|
@@ -285,7 +311,8 @@ void add_link(
|
|
285
311
|
DistanceComputer& qdis,
|
286
312
|
storage_idx_t src,
|
287
313
|
storage_idx_t dest,
|
288
|
-
int level
|
314
|
+
int level,
|
315
|
+
bool keep_max_size_level0 = false) {
|
289
316
|
size_t begin, end;
|
290
317
|
hnsw.neighbor_range(src, level, &begin, &end);
|
291
318
|
if (hnsw.neighbors[end - 1] == -1) {
|
@@ -310,7 +337,7 @@ void add_link(
|
|
310
337
|
resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
|
311
338
|
}
|
312
339
|
|
313
|
-
shrink_neighbor_list(qdis, resultSet, end - begin);
|
340
|
+
shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
|
314
341
|
|
315
342
|
// ...and back
|
316
343
|
size_t i = begin;
|
@@ -333,6 +360,9 @@ void search_neighbors_to_add(
|
|
333
360
|
float d_entry_point,
|
334
361
|
int level,
|
335
362
|
VisitedTable& vt) {
|
363
|
+
// selects a version
|
364
|
+
const bool reference_version = false;
|
365
|
+
|
336
366
|
// top is nearest candidate
|
337
367
|
std::priority_queue<NodeDistFarther> candidates;
|
338
368
|
|
@@ -354,59 +384,90 @@ void search_neighbors_to_add(
|
|
354
384
|
// loop over neighbors
|
355
385
|
size_t begin, end;
|
356
386
|
hnsw.neighbor_range(currNode, level, &begin, &end);
|
357
|
-
for (size_t i = begin; i < end; i++) {
|
358
|
-
storage_idx_t nodeId = hnsw.neighbors[i];
|
359
|
-
if (nodeId < 0)
|
360
|
-
break;
|
361
|
-
if (vt.get(nodeId))
|
362
|
-
continue;
|
363
|
-
vt.set(nodeId);
|
364
|
-
|
365
|
-
float dis = qdis(nodeId);
|
366
|
-
NodeDistFarther evE1(dis, nodeId);
|
367
387
|
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
388
|
+
// select a version, based on a flag
|
389
|
+
if (reference_version) {
|
390
|
+
// a reference version
|
391
|
+
for (size_t i = begin; i < end; i++) {
|
392
|
+
storage_idx_t nodeId = hnsw.neighbors[i];
|
393
|
+
if (nodeId < 0)
|
394
|
+
break;
|
395
|
+
if (vt.get(nodeId))
|
396
|
+
continue;
|
397
|
+
vt.set(nodeId);
|
398
|
+
|
399
|
+
float dis = qdis(nodeId);
|
400
|
+
NodeDistFarther evE1(dis, nodeId);
|
401
|
+
|
402
|
+
if (results.size() < hnsw.efConstruction ||
|
403
|
+
results.top().d > dis) {
|
404
|
+
results.emplace(dis, nodeId);
|
405
|
+
candidates.emplace(dis, nodeId);
|
406
|
+
if (results.size() > hnsw.efConstruction) {
|
407
|
+
results.pop();
|
408
|
+
}
|
373
409
|
}
|
374
410
|
}
|
375
|
-
}
|
376
|
-
|
377
|
-
|
378
|
-
|
411
|
+
} else {
|
412
|
+
// a faster version
|
413
|
+
|
414
|
+
// the following version processes 4 neighbors at a time
|
415
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
416
|
+
const float dis) {
|
417
|
+
if (results.size() < hnsw.efConstruction ||
|
418
|
+
results.top().d > dis) {
|
419
|
+
results.emplace(dis, idx);
|
420
|
+
candidates.emplace(dis, idx);
|
421
|
+
if (results.size() > hnsw.efConstruction) {
|
422
|
+
results.pop();
|
423
|
+
}
|
424
|
+
}
|
425
|
+
};
|
379
426
|
|
380
|
-
|
381
|
-
|
382
|
-
**************************************************************/
|
427
|
+
int n_buffered = 0;
|
428
|
+
storage_idx_t buffered_ids[4];
|
383
429
|
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
430
|
+
for (size_t j = begin; j < end; j++) {
|
431
|
+
storage_idx_t nodeId = hnsw.neighbors[j];
|
432
|
+
if (nodeId < 0)
|
433
|
+
break;
|
434
|
+
if (vt.get(nodeId)) {
|
435
|
+
continue;
|
436
|
+
}
|
437
|
+
vt.set(nodeId);
|
438
|
+
|
439
|
+
buffered_ids[n_buffered] = nodeId;
|
440
|
+
n_buffered += 1;
|
441
|
+
|
442
|
+
if (n_buffered == 4) {
|
443
|
+
float dis[4];
|
444
|
+
qdis.distances_batch_4(
|
445
|
+
buffered_ids[0],
|
446
|
+
buffered_ids[1],
|
447
|
+
buffered_ids[2],
|
448
|
+
buffered_ids[3],
|
449
|
+
dis[0],
|
450
|
+
dis[1],
|
451
|
+
dis[2],
|
452
|
+
dis[3]);
|
453
|
+
|
454
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
455
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
456
|
+
}
|
393
457
|
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
nearest = v;
|
403
|
-
d_nearest = dis;
|
458
|
+
n_buffered = 0;
|
459
|
+
}
|
460
|
+
}
|
461
|
+
|
462
|
+
// process leftovers
|
463
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
464
|
+
float dis = qdis(buffered_ids[icnt]);
|
465
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
404
466
|
}
|
405
|
-
}
|
406
|
-
if (nearest == prev_nearest) {
|
407
|
-
return;
|
408
467
|
}
|
409
468
|
}
|
469
|
+
|
470
|
+
vt.advance();
|
410
471
|
}
|
411
472
|
|
412
473
|
} // namespace
|
@@ -420,7 +481,8 @@ void HNSW::add_links_starting_from(
|
|
420
481
|
float d_nearest,
|
421
482
|
int level,
|
422
483
|
omp_lock_t* locks,
|
423
|
-
VisitedTable& vt
|
484
|
+
VisitedTable& vt,
|
485
|
+
bool keep_max_size_level0) {
|
424
486
|
std::priority_queue<NodeDistCloser> link_targets;
|
425
487
|
|
426
488
|
search_neighbors_to_add(
|
@@ -429,13 +491,13 @@ void HNSW::add_links_starting_from(
|
|
429
491
|
// but we can afford only this many neighbors
|
430
492
|
int M = nb_neighbors(level);
|
431
493
|
|
432
|
-
::faiss::shrink_neighbor_list(ptdis, link_targets, M);
|
494
|
+
::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
|
433
495
|
|
434
496
|
std::vector<storage_idx_t> neighbors;
|
435
497
|
neighbors.reserve(link_targets.size());
|
436
498
|
while (!link_targets.empty()) {
|
437
499
|
storage_idx_t other_id = link_targets.top().id;
|
438
|
-
add_link(*this, ptdis, pt_id, other_id, level);
|
500
|
+
add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
|
439
501
|
neighbors.push_back(other_id);
|
440
502
|
link_targets.pop();
|
441
503
|
}
|
@@ -443,7 +505,7 @@ void HNSW::add_links_starting_from(
|
|
443
505
|
omp_unset_lock(&locks[pt_id]);
|
444
506
|
for (storage_idx_t other_id : neighbors) {
|
445
507
|
omp_set_lock(&locks[other_id]);
|
446
|
-
add_link(*this, ptdis, other_id, pt_id, level);
|
508
|
+
add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
|
447
509
|
omp_unset_lock(&locks[other_id]);
|
448
510
|
}
|
449
511
|
omp_set_lock(&locks[pt_id]);
|
@@ -458,7 +520,8 @@ void HNSW::add_with_locks(
|
|
458
520
|
int pt_level,
|
459
521
|
int pt_id,
|
460
522
|
std::vector<omp_lock_t>& locks,
|
461
|
-
VisitedTable& vt
|
523
|
+
VisitedTable& vt,
|
524
|
+
bool keep_max_size_level0) {
|
462
525
|
// greedy search on upper levels
|
463
526
|
|
464
527
|
storage_idx_t nearest;
|
@@ -487,7 +550,14 @@ void HNSW::add_with_locks(
|
|
487
550
|
|
488
551
|
for (; level >= 0; level--) {
|
489
552
|
add_links_starting_from(
|
490
|
-
ptdis,
|
553
|
+
ptdis,
|
554
|
+
pt_id,
|
555
|
+
nearest,
|
556
|
+
d_nearest,
|
557
|
+
level,
|
558
|
+
locks.data(),
|
559
|
+
vt,
|
560
|
+
keep_max_size_level0);
|
491
561
|
}
|
492
562
|
|
493
563
|
omp_unset_lock(&locks[pt_id]);
|
@@ -502,24 +572,20 @@ void HNSW::add_with_locks(
|
|
502
572
|
* Searching
|
503
573
|
**************************************************************/
|
504
574
|
|
505
|
-
namespace {
|
506
|
-
|
507
575
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
508
576
|
using Node = HNSW::Node;
|
577
|
+
using C = HNSW::C;
|
509
578
|
/** Do a BFS on the candidates list */
|
510
|
-
|
511
579
|
int search_from_candidates(
|
512
580
|
const HNSW& hnsw,
|
513
581
|
DistanceComputer& qdis,
|
514
|
-
|
515
|
-
idx_t* I,
|
516
|
-
float* D,
|
582
|
+
ResultHandler<C>& res,
|
517
583
|
MinimaxHeap& candidates,
|
518
584
|
VisitedTable& vt,
|
519
585
|
HNSWStats& stats,
|
520
586
|
int level,
|
521
|
-
int nres_in
|
522
|
-
const SearchParametersHNSW* params
|
587
|
+
int nres_in,
|
588
|
+
const SearchParametersHNSW* params) {
|
523
589
|
int nres = nres_in;
|
524
590
|
int ndis = 0;
|
525
591
|
|
@@ -529,15 +595,16 @@ int search_from_candidates(
|
|
529
595
|
int efSearch = params ? params->efSearch : hnsw.efSearch;
|
530
596
|
const IDSelector* sel = params ? params->sel : nullptr;
|
531
597
|
|
598
|
+
C::T threshold = res.threshold;
|
532
599
|
for (int i = 0; i < candidates.size(); i++) {
|
533
600
|
idx_t v1 = candidates.ids[i];
|
534
601
|
float d = candidates.dis[i];
|
535
602
|
FAISS_ASSERT(v1 >= 0);
|
536
603
|
if (!sel || sel->is_member(v1)) {
|
537
|
-
if (
|
538
|
-
|
539
|
-
|
540
|
-
|
604
|
+
if (d < threshold) {
|
605
|
+
if (res.add_result(d, v1)) {
|
606
|
+
threshold = res.threshold;
|
607
|
+
}
|
541
608
|
}
|
542
609
|
}
|
543
610
|
vt.set(v1);
|
@@ -563,24 +630,70 @@ int search_from_candidates(
|
|
563
630
|
size_t begin, end;
|
564
631
|
hnsw.neighbor_range(v0, level, &begin, &end);
|
565
632
|
|
633
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
634
|
+
// the following version processes 4 neighbors at a time
|
635
|
+
size_t jmax = begin;
|
566
636
|
for (size_t j = begin; j < end; j++) {
|
567
637
|
int v1 = hnsw.neighbors[j];
|
568
638
|
if (v1 < 0)
|
569
639
|
break;
|
570
|
-
|
571
|
-
|
640
|
+
|
641
|
+
prefetch_L2(vt.visited.data() + v1);
|
642
|
+
jmax += 1;
|
643
|
+
}
|
644
|
+
|
645
|
+
int counter = 0;
|
646
|
+
size_t saved_j[4];
|
647
|
+
|
648
|
+
threshold = res.threshold;
|
649
|
+
|
650
|
+
auto add_to_heap = [&](const size_t idx, const float dis) {
|
651
|
+
if (!sel || sel->is_member(idx)) {
|
652
|
+
if (dis < threshold) {
|
653
|
+
if (res.add_result(dis, idx)) {
|
654
|
+
threshold = res.threshold;
|
655
|
+
nres += 1;
|
656
|
+
}
|
657
|
+
}
|
572
658
|
}
|
659
|
+
candidates.push(idx, dis);
|
660
|
+
};
|
661
|
+
|
662
|
+
for (size_t j = begin; j < jmax; j++) {
|
663
|
+
int v1 = hnsw.neighbors[j];
|
664
|
+
|
665
|
+
bool vget = vt.get(v1);
|
573
666
|
vt.set(v1);
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
667
|
+
saved_j[counter] = v1;
|
668
|
+
counter += vget ? 0 : 1;
|
669
|
+
|
670
|
+
if (counter == 4) {
|
671
|
+
float dis[4];
|
672
|
+
qdis.distances_batch_4(
|
673
|
+
saved_j[0],
|
674
|
+
saved_j[1],
|
675
|
+
saved_j[2],
|
676
|
+
saved_j[3],
|
677
|
+
dis[0],
|
678
|
+
dis[1],
|
679
|
+
dis[2],
|
680
|
+
dis[3]);
|
681
|
+
|
682
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
683
|
+
add_to_heap(saved_j[id4], dis[id4]);
|
581
684
|
}
|
685
|
+
|
686
|
+
ndis += 4;
|
687
|
+
|
688
|
+
counter = 0;
|
582
689
|
}
|
583
|
-
|
690
|
+
}
|
691
|
+
|
692
|
+
for (size_t icnt = 0; icnt < counter; icnt++) {
|
693
|
+
float dis = qdis(saved_j[icnt]);
|
694
|
+
add_to_heap(saved_j[icnt], dis);
|
695
|
+
|
696
|
+
ndis += 1;
|
584
697
|
}
|
585
698
|
|
586
699
|
nstep++;
|
@@ -594,7 +707,8 @@ int search_from_candidates(
|
|
594
707
|
if (candidates.size() == 0) {
|
595
708
|
stats.n2++;
|
596
709
|
}
|
597
|
-
stats.
|
710
|
+
stats.ndis += ndis;
|
711
|
+
stats.nhops += nstep;
|
598
712
|
}
|
599
713
|
|
600
714
|
return nres;
|
@@ -630,151 +744,241 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
630
744
|
size_t begin, end;
|
631
745
|
hnsw.neighbor_range(v0, 0, &begin, &end);
|
632
746
|
|
633
|
-
|
747
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
748
|
+
// the following version processes 4 neighbors at a time
|
749
|
+
size_t jmax = begin;
|
750
|
+
for (size_t j = begin; j < end; j++) {
|
634
751
|
int v1 = hnsw.neighbors[j];
|
635
|
-
|
636
|
-
if (v1 < 0) {
|
752
|
+
if (v1 < 0)
|
637
753
|
break;
|
638
|
-
}
|
639
|
-
if (vt->get(v1)) {
|
640
|
-
continue;
|
641
|
-
}
|
642
754
|
|
643
|
-
vt->
|
755
|
+
prefetch_L2(vt->visited.data() + v1);
|
756
|
+
jmax += 1;
|
757
|
+
}
|
644
758
|
|
645
|
-
|
646
|
-
|
759
|
+
int counter = 0;
|
760
|
+
size_t saved_j[4];
|
647
761
|
|
648
|
-
|
649
|
-
|
650
|
-
top_candidates.
|
762
|
+
auto add_to_heap = [&](const size_t idx, const float dis) {
|
763
|
+
if (top_candidates.top().first > dis ||
|
764
|
+
top_candidates.size() < ef) {
|
765
|
+
candidates.emplace(dis, idx);
|
766
|
+
top_candidates.emplace(dis, idx);
|
651
767
|
|
652
768
|
if (top_candidates.size() > ef) {
|
653
769
|
top_candidates.pop();
|
654
770
|
}
|
655
771
|
}
|
772
|
+
};
|
773
|
+
|
774
|
+
for (size_t j = begin; j < jmax; j++) {
|
775
|
+
int v1 = hnsw.neighbors[j];
|
776
|
+
|
777
|
+
bool vget = vt->get(v1);
|
778
|
+
vt->set(v1);
|
779
|
+
saved_j[counter] = v1;
|
780
|
+
counter += vget ? 0 : 1;
|
781
|
+
|
782
|
+
if (counter == 4) {
|
783
|
+
float dis[4];
|
784
|
+
qdis.distances_batch_4(
|
785
|
+
saved_j[0],
|
786
|
+
saved_j[1],
|
787
|
+
saved_j[2],
|
788
|
+
saved_j[3],
|
789
|
+
dis[0],
|
790
|
+
dis[1],
|
791
|
+
dis[2],
|
792
|
+
dis[3]);
|
793
|
+
|
794
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
795
|
+
add_to_heap(saved_j[id4], dis[id4]);
|
796
|
+
}
|
797
|
+
|
798
|
+
ndis += 4;
|
799
|
+
|
800
|
+
counter = 0;
|
801
|
+
}
|
802
|
+
}
|
803
|
+
|
804
|
+
for (size_t icnt = 0; icnt < counter; icnt++) {
|
805
|
+
float dis = qdis(saved_j[icnt]);
|
806
|
+
add_to_heap(saved_j[icnt], dis);
|
807
|
+
|
808
|
+
ndis += 1;
|
656
809
|
}
|
810
|
+
|
811
|
+
stats.nhops += 1;
|
657
812
|
}
|
658
813
|
|
659
814
|
++stats.n1;
|
660
815
|
if (candidates.size() == 0) {
|
661
816
|
++stats.n2;
|
662
817
|
}
|
663
|
-
stats.
|
818
|
+
stats.ndis += ndis;
|
664
819
|
|
665
820
|
return top_candidates;
|
666
821
|
}
|
667
822
|
|
668
|
-
|
669
|
-
|
670
|
-
|
823
|
+
/// greedily update a nearest vector at a given level
|
824
|
+
HNSWStats greedy_update_nearest(
|
825
|
+
const HNSW& hnsw,
|
671
826
|
DistanceComputer& qdis,
|
672
|
-
int
|
673
|
-
|
674
|
-
float
|
675
|
-
VisitedTable& vt,
|
676
|
-
const SearchParametersHNSW* params) const {
|
827
|
+
int level,
|
828
|
+
storage_idx_t& nearest,
|
829
|
+
float& d_nearest) {
|
677
830
|
HNSWStats stats;
|
678
|
-
if (entry_point == -1) {
|
679
|
-
return stats;
|
680
|
-
}
|
681
|
-
if (upper_beam == 1) {
|
682
|
-
// greedy search on upper levels
|
683
|
-
storage_idx_t nearest = entry_point;
|
684
|
-
float d_nearest = qdis(nearest);
|
685
831
|
|
686
|
-
|
687
|
-
|
688
|
-
}
|
832
|
+
for (;;) {
|
833
|
+
storage_idx_t prev_nearest = nearest;
|
689
834
|
|
690
|
-
|
691
|
-
|
692
|
-
MinimaxHeap candidates(ef);
|
835
|
+
size_t begin, end;
|
836
|
+
hnsw.neighbor_range(nearest, level, &begin, &end);
|
693
837
|
|
694
|
-
|
838
|
+
size_t ndis = 0;
|
695
839
|
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
qdis,
|
704
|
-
ef,
|
705
|
-
&vt,
|
706
|
-
stats);
|
707
|
-
|
708
|
-
while (top_candidates.size() > k) {
|
709
|
-
top_candidates.pop();
|
840
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
841
|
+
// the following version processes 4 neighbors at a time
|
842
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
843
|
+
const float dis) {
|
844
|
+
if (dis < d_nearest) {
|
845
|
+
nearest = idx;
|
846
|
+
d_nearest = dis;
|
710
847
|
}
|
848
|
+
};
|
849
|
+
|
850
|
+
int n_buffered = 0;
|
851
|
+
storage_idx_t buffered_ids[4];
|
852
|
+
|
853
|
+
for (size_t j = begin; j < end; j++) {
|
854
|
+
storage_idx_t v = hnsw.neighbors[j];
|
855
|
+
if (v < 0)
|
856
|
+
break;
|
857
|
+
ndis += 1;
|
858
|
+
|
859
|
+
buffered_ids[n_buffered] = v;
|
860
|
+
n_buffered += 1;
|
861
|
+
|
862
|
+
if (n_buffered == 4) {
|
863
|
+
float dis[4];
|
864
|
+
qdis.distances_batch_4(
|
865
|
+
buffered_ids[0],
|
866
|
+
buffered_ids[1],
|
867
|
+
buffered_ids[2],
|
868
|
+
buffered_ids[3],
|
869
|
+
dis[0],
|
870
|
+
dis[1],
|
871
|
+
dis[2],
|
872
|
+
dis[3]);
|
873
|
+
|
874
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
875
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
876
|
+
}
|
711
877
|
|
712
|
-
|
713
|
-
while (!top_candidates.empty()) {
|
714
|
-
float d;
|
715
|
-
storage_idx_t label;
|
716
|
-
std::tie(d, label) = top_candidates.top();
|
717
|
-
faiss::maxheap_push(++nres, D, I, d, label);
|
718
|
-
top_candidates.pop();
|
878
|
+
n_buffered = 0;
|
719
879
|
}
|
720
880
|
}
|
721
881
|
|
722
|
-
|
882
|
+
// process leftovers
|
883
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
884
|
+
float dis = qdis(buffered_ids[icnt]);
|
885
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
886
|
+
}
|
723
887
|
|
724
|
-
|
725
|
-
|
726
|
-
|
888
|
+
// update stats
|
889
|
+
stats.ndis += ndis;
|
890
|
+
stats.nhops += 1;
|
727
891
|
|
728
|
-
|
729
|
-
|
892
|
+
if (nearest == prev_nearest) {
|
893
|
+
return stats;
|
894
|
+
}
|
895
|
+
}
|
896
|
+
}
|
730
897
|
|
731
|
-
|
732
|
-
|
733
|
-
|
898
|
+
namespace {
|
899
|
+
using MinimaxHeap = HNSW::MinimaxHeap;
|
900
|
+
using Node = HNSW::Node;
|
901
|
+
using C = HNSW::C;
|
734
902
|
|
735
|
-
|
736
|
-
|
903
|
+
// just used as a lower bound for the minmaxheap, but it is set for heap search
|
904
|
+
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
905
|
+
using RH = HeapBlockResultHandler<C>;
|
906
|
+
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
|
907
|
+
return hres->k;
|
908
|
+
}
|
909
|
+
return 1;
|
910
|
+
}
|
737
911
|
|
738
|
-
|
912
|
+
} // namespace
|
739
913
|
|
740
|
-
|
741
|
-
|
742
|
-
|
914
|
+
HNSWStats HNSW::search(
|
915
|
+
DistanceComputer& qdis,
|
916
|
+
ResultHandler<C>& res,
|
917
|
+
VisitedTable& vt,
|
918
|
+
const SearchParametersHNSW* params) const {
|
919
|
+
HNSWStats stats;
|
920
|
+
if (entry_point == -1) {
|
921
|
+
return stats;
|
922
|
+
}
|
923
|
+
int k = extract_k_from_ResultHandler(res);
|
743
924
|
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
925
|
+
bool bounded_queue =
|
926
|
+
params ? params->bounded_queue : this->search_bounded_queue;
|
927
|
+
|
928
|
+
// greedy search on upper levels
|
929
|
+
storage_idx_t nearest = entry_point;
|
930
|
+
float d_nearest = qdis(nearest);
|
931
|
+
|
932
|
+
for (int level = max_level; level >= 1; level--) {
|
933
|
+
HNSWStats local_stats =
|
934
|
+
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
935
|
+
stats.combine(local_stats);
|
936
|
+
}
|
937
|
+
|
938
|
+
int ef = std::max(params ? params->efSearch : efSearch, k);
|
939
|
+
if (bounded_queue) { // this is the most common branch
|
940
|
+
MinimaxHeap candidates(ef);
|
941
|
+
|
942
|
+
candidates.push(nearest, d_nearest);
|
943
|
+
|
944
|
+
search_from_candidates(
|
945
|
+
*this, qdis, res, candidates, vt, stats, 0, 0, params);
|
946
|
+
} else {
|
947
|
+
std::priority_queue<Node> top_candidates =
|
948
|
+
search_from_candidate_unbounded(
|
949
|
+
*this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
|
950
|
+
|
951
|
+
while (top_candidates.size() > k) {
|
952
|
+
top_candidates.pop();
|
953
|
+
}
|
954
|
+
|
955
|
+
while (!top_candidates.empty()) {
|
956
|
+
float d;
|
957
|
+
storage_idx_t label;
|
958
|
+
std::tie(d, label) = top_candidates.top();
|
959
|
+
res.add_result(d, label);
|
960
|
+
top_candidates.pop();
|
760
961
|
}
|
761
962
|
}
|
762
963
|
|
964
|
+
vt.advance();
|
965
|
+
|
763
966
|
return stats;
|
764
967
|
}
|
765
968
|
|
766
969
|
void HNSW::search_level_0(
|
767
970
|
DistanceComputer& qdis,
|
768
|
-
|
769
|
-
idx_t* idxi,
|
770
|
-
float* simi,
|
971
|
+
ResultHandler<C>& res,
|
771
972
|
idx_t nprobe,
|
772
973
|
const storage_idx_t* nearest_i,
|
773
974
|
const float* nearest_d,
|
774
975
|
int search_type,
|
775
976
|
HNSWStats& search_stats,
|
776
|
-
VisitedTable& vt
|
977
|
+
VisitedTable& vt,
|
978
|
+
const SearchParametersHNSW* params) const {
|
777
979
|
const HNSW& hnsw = *this;
|
980
|
+
auto efSearch = params ? params->efSearch : hnsw.efSearch;
|
981
|
+
int k = extract_k_from_ResultHandler(res);
|
778
982
|
|
779
983
|
if (search_type == 1) {
|
780
984
|
int nres = 0;
|
@@ -788,7 +992,7 @@ void HNSW::search_level_0(
|
|
788
992
|
if (vt.get(cj))
|
789
993
|
continue;
|
790
994
|
|
791
|
-
int candidates_size = std::max(
|
995
|
+
int candidates_size = std::max(efSearch, k);
|
792
996
|
MinimaxHeap candidates(candidates_size);
|
793
997
|
|
794
998
|
candidates.push(cj, nearest_d[j]);
|
@@ -796,17 +1000,17 @@ void HNSW::search_level_0(
|
|
796
1000
|
nres = search_from_candidates(
|
797
1001
|
hnsw,
|
798
1002
|
qdis,
|
799
|
-
|
800
|
-
idxi,
|
801
|
-
simi,
|
1003
|
+
res,
|
802
1004
|
candidates,
|
803
1005
|
vt,
|
804
1006
|
search_stats,
|
805
1007
|
0,
|
806
|
-
nres
|
1008
|
+
nres,
|
1009
|
+
params);
|
1010
|
+
nres = std::min(nres, candidates_size);
|
807
1011
|
}
|
808
1012
|
} else if (search_type == 2) {
|
809
|
-
int candidates_size = std::max(
|
1013
|
+
int candidates_size = std::max(efSearch, int(k));
|
810
1014
|
candidates_size = std::max(candidates_size, int(nprobe));
|
811
1015
|
|
812
1016
|
MinimaxHeap candidates(candidates_size);
|
@@ -819,10 +1023,43 @@ void HNSW::search_level_0(
|
|
819
1023
|
}
|
820
1024
|
|
821
1025
|
search_from_candidates(
|
822
|
-
hnsw, qdis,
|
1026
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
|
823
1027
|
}
|
824
1028
|
}
|
825
1029
|
|
1030
|
+
void HNSW::permute_entries(const idx_t* map) {
|
1031
|
+
// remap levels
|
1032
|
+
storage_idx_t ntotal = levels.size();
|
1033
|
+
std::vector<storage_idx_t> imap(ntotal); // inverse mapping
|
1034
|
+
// map: new index -> old index
|
1035
|
+
// imap: old index -> new index
|
1036
|
+
for (int i = 0; i < ntotal; i++) {
|
1037
|
+
assert(map[i] >= 0 && map[i] < ntotal);
|
1038
|
+
imap[map[i]] = i;
|
1039
|
+
}
|
1040
|
+
if (entry_point != -1) {
|
1041
|
+
entry_point = imap[entry_point];
|
1042
|
+
}
|
1043
|
+
std::vector<int> new_levels(ntotal);
|
1044
|
+
std::vector<size_t> new_offsets(ntotal + 1);
|
1045
|
+
std::vector<storage_idx_t> new_neighbors(neighbors.size());
|
1046
|
+
size_t no = 0;
|
1047
|
+
for (int i = 0; i < ntotal; i++) {
|
1048
|
+
storage_idx_t o = map[i]; // corresponding "old" index
|
1049
|
+
new_levels[i] = levels[o];
|
1050
|
+
for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
|
1051
|
+
storage_idx_t neigh = neighbors[j];
|
1052
|
+
new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
|
1053
|
+
}
|
1054
|
+
new_offsets[i + 1] = no;
|
1055
|
+
}
|
1056
|
+
assert(new_offsets[ntotal] == offsets[ntotal]);
|
1057
|
+
// swap everyone
|
1058
|
+
std::swap(levels, new_levels);
|
1059
|
+
std::swap(offsets, new_offsets);
|
1060
|
+
std::swap(neighbors, new_neighbors);
|
1061
|
+
}
|
1062
|
+
|
826
1063
|
/**************************************************************
|
827
1064
|
* MinimaxHeap
|
828
1065
|
**************************************************************/
|
@@ -852,17 +1089,197 @@ void HNSW::MinimaxHeap::clear() {
|
|
852
1089
|
nvalid = k = 0;
|
853
1090
|
}
|
854
1091
|
|
1092
|
+
#ifdef __AVX512F__
|
1093
|
+
|
1094
|
+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
1095
|
+
assert(k > 0);
|
1096
|
+
static_assert(
|
1097
|
+
std::is_same<storage_idx_t, int32_t>::value,
|
1098
|
+
"This code expects storage_idx_t to be int32_t");
|
1099
|
+
|
1100
|
+
int32_t min_idx = -1;
|
1101
|
+
float min_dis = std::numeric_limits<float>::infinity();
|
1102
|
+
|
1103
|
+
__m512i min_indices = _mm512_set1_epi32(-1);
|
1104
|
+
__m512 min_distances =
|
1105
|
+
_mm512_set1_ps(std::numeric_limits<float>::infinity());
|
1106
|
+
__m512i current_indices = _mm512_setr_epi32(
|
1107
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
1108
|
+
__m512i offset = _mm512_set1_epi32(16);
|
1109
|
+
|
1110
|
+
// The following loop tracks the rightmost index with the min distance.
|
1111
|
+
// -1 index values are ignored.
|
1112
|
+
const int k16 = (k / 16) * 16;
|
1113
|
+
for (size_t iii = 0; iii < k16; iii += 16) {
|
1114
|
+
__m512i indices =
|
1115
|
+
_mm512_loadu_si512((const __m512i*)(ids.data() + iii));
|
1116
|
+
__m512 distances = _mm512_loadu_ps(dis.data() + iii);
|
1117
|
+
|
1118
|
+
// This mask filters out -1 values among indices.
|
1119
|
+
__mmask16 m1mask =
|
1120
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
1121
|
+
|
1122
|
+
__mmask16 dmask =
|
1123
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
1124
|
+
__mmask16 finalmask = m1mask | dmask;
|
1125
|
+
|
1126
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
1127
|
+
finalmask, current_indices, min_indices);
|
1128
|
+
const __m512 min_distances_new =
|
1129
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
1130
|
+
|
1131
|
+
min_indices = min_indices_new;
|
1132
|
+
min_distances = min_distances_new;
|
1133
|
+
|
1134
|
+
current_indices = _mm512_add_epi32(current_indices, offset);
|
1135
|
+
}
|
1136
|
+
|
1137
|
+
// leftovers
|
1138
|
+
if (k16 != k) {
|
1139
|
+
const __mmask16 kmask = (1 << (k - k16)) - 1;
|
1140
|
+
|
1141
|
+
__m512i indices = _mm512_mask_loadu_epi32(
|
1142
|
+
_mm512_set1_epi32(-1), kmask, ids.data() + k16);
|
1143
|
+
__m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
|
1144
|
+
|
1145
|
+
// This mask filters out -1 values among indices.
|
1146
|
+
__mmask16 m1mask =
|
1147
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
1148
|
+
|
1149
|
+
__mmask16 dmask =
|
1150
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
1151
|
+
__mmask16 finalmask = m1mask | dmask;
|
1152
|
+
|
1153
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
1154
|
+
finalmask, current_indices, min_indices);
|
1155
|
+
const __m512 min_distances_new =
|
1156
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
1157
|
+
|
1158
|
+
min_indices = min_indices_new;
|
1159
|
+
min_distances = min_distances_new;
|
1160
|
+
}
|
1161
|
+
|
1162
|
+
// grab min distance
|
1163
|
+
min_dis = _mm512_reduce_min_ps(min_distances);
|
1164
|
+
// blend
|
1165
|
+
__mmask16 mindmask =
|
1166
|
+
_mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
|
1167
|
+
// pick the max one
|
1168
|
+
min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
|
1169
|
+
|
1170
|
+
if (min_idx == -1) {
|
1171
|
+
return -1;
|
1172
|
+
}
|
1173
|
+
|
1174
|
+
if (vmin_out) {
|
1175
|
+
*vmin_out = min_dis;
|
1176
|
+
}
|
1177
|
+
int ret = ids[min_idx];
|
1178
|
+
ids[min_idx] = -1;
|
1179
|
+
--nvalid;
|
1180
|
+
return ret;
|
1181
|
+
}
|
1182
|
+
|
1183
|
+
#elif __AVX2__
|
1184
|
+
|
1185
|
+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
1186
|
+
assert(k > 0);
|
1187
|
+
static_assert(
|
1188
|
+
std::is_same<storage_idx_t, int32_t>::value,
|
1189
|
+
"This code expects storage_idx_t to be int32_t");
|
1190
|
+
|
1191
|
+
int32_t min_idx = -1;
|
1192
|
+
float min_dis = std::numeric_limits<float>::infinity();
|
1193
|
+
|
1194
|
+
size_t iii = 0;
|
1195
|
+
|
1196
|
+
__m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
|
1197
|
+
__m256 min_distances =
|
1198
|
+
_mm256_set1_ps(std::numeric_limits<float>::infinity());
|
1199
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
1200
|
+
__m256i offset = _mm256_set1_epi32(8);
|
1201
|
+
|
1202
|
+
// The baseline version is available in non-AVX2 branch.
|
1203
|
+
|
1204
|
+
// The following loop tracks the rightmost index with the min distance.
|
1205
|
+
// -1 index values are ignored.
|
1206
|
+
const int k8 = (k / 8) * 8;
|
1207
|
+
for (; iii < k8; iii += 8) {
|
1208
|
+
__m256i indices =
|
1209
|
+
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
|
1210
|
+
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
|
1211
|
+
|
1212
|
+
// This mask filters out -1 values among indices.
|
1213
|
+
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
|
1214
|
+
|
1215
|
+
__m256i dmask = _mm256_castps_si256(
|
1216
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
|
1217
|
+
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
|
1218
|
+
|
1219
|
+
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
|
1220
|
+
_mm256_castsi256_ps(current_indices),
|
1221
|
+
_mm256_castsi256_ps(min_indices),
|
1222
|
+
finalmask));
|
1223
|
+
|
1224
|
+
const __m256 min_distances_new =
|
1225
|
+
_mm256_blendv_ps(distances, min_distances, finalmask);
|
1226
|
+
|
1227
|
+
min_indices = min_indices_new;
|
1228
|
+
min_distances = min_distances_new;
|
1229
|
+
|
1230
|
+
current_indices = _mm256_add_epi32(current_indices, offset);
|
1231
|
+
}
|
1232
|
+
|
1233
|
+
// Vectorizing is doable, but is not practical
|
1234
|
+
int32_t vidx8[8];
|
1235
|
+
float vdis8[8];
|
1236
|
+
_mm256_storeu_ps(vdis8, min_distances);
|
1237
|
+
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
|
1238
|
+
|
1239
|
+
for (size_t j = 0; j < 8; j++) {
|
1240
|
+
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
|
1241
|
+
min_idx = vidx8[j];
|
1242
|
+
min_dis = vdis8[j];
|
1243
|
+
}
|
1244
|
+
}
|
1245
|
+
|
1246
|
+
// process last values. Vectorizing is doable, but is not practical
|
1247
|
+
for (; iii < k; iii++) {
|
1248
|
+
if (ids[iii] != -1 && dis[iii] <= min_dis) {
|
1249
|
+
min_dis = dis[iii];
|
1250
|
+
min_idx = iii;
|
1251
|
+
}
|
1252
|
+
}
|
1253
|
+
|
1254
|
+
if (min_idx == -1) {
|
1255
|
+
return -1;
|
1256
|
+
}
|
1257
|
+
|
1258
|
+
if (vmin_out) {
|
1259
|
+
*vmin_out = min_dis;
|
1260
|
+
}
|
1261
|
+
int ret = ids[min_idx];
|
1262
|
+
ids[min_idx] = -1;
|
1263
|
+
--nvalid;
|
1264
|
+
return ret;
|
1265
|
+
}
|
1266
|
+
|
1267
|
+
#else
|
1268
|
+
|
1269
|
+
// baseline non-vectorized version
|
855
1270
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
856
1271
|
assert(k > 0);
|
857
1272
|
// returns min. This is an O(n) operation
|
858
1273
|
int i = k - 1;
|
859
1274
|
while (i >= 0) {
|
860
|
-
if (ids[i] != -1)
|
1275
|
+
if (ids[i] != -1) {
|
861
1276
|
break;
|
1277
|
+
}
|
862
1278
|
i--;
|
863
1279
|
}
|
864
|
-
if (i == -1)
|
1280
|
+
if (i == -1) {
|
865
1281
|
return -1;
|
1282
|
+
}
|
866
1283
|
int imin = i;
|
867
1284
|
float vmin = dis[i];
|
868
1285
|
i--;
|
@@ -873,14 +1290,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
873
1290
|
}
|
874
1291
|
i--;
|
875
1292
|
}
|
876
|
-
if (vmin_out)
|
1293
|
+
if (vmin_out) {
|
877
1294
|
*vmin_out = vmin;
|
1295
|
+
}
|
878
1296
|
int ret = ids[imin];
|
879
1297
|
ids[imin] = -1;
|
880
1298
|
--nvalid;
|
881
1299
|
|
882
1300
|
return ret;
|
883
1301
|
}
|
1302
|
+
#endif
|
884
1303
|
|
885
1304
|
int HNSW::MinimaxHeap::count_below(float thresh) {
|
886
1305
|
int n_below = 0;
|