faiss 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
@@ -0,0 +1,95 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#ifndef FAISS_ASSERT_INCLUDED
|
11
|
+
#define FAISS_ASSERT_INCLUDED
|
12
|
+
|
13
|
+
#include <faiss/impl/FaissException.h>
|
14
|
+
#include <cstdlib>
|
15
|
+
#include <cstdio>
|
16
|
+
#include <string>
|
17
|
+
|
18
|
+
///
|
19
|
+
/// Assertions
|
20
|
+
///
|
21
|
+
|
22
|
+
/// Aborts the process when X is false. Before aborting, the stringified
/// expression, the enclosing function and the file/line of the call site are
/// written to stderr. X is evaluated exactly once.
#define FAISS_ASSERT(X)                                                     \
    do {                                                                    \
        if (!(X)) {                                                         \
            std::fprintf(stderr,                                            \
                         "Faiss assertion '%s' failed in %s at %s:%d\n",    \
                         #X, __PRETTY_FUNCTION__, __FILE__, __LINE__);      \
            std::abort();                                                   \
        }                                                                   \
    } while (false)
|
31
|
+
|
32
|
+
/// Same as FAISS_ASSERT, with an extra explanation appended to the report.
/// MSG is pasted into the printf format string, so it has to be a string
/// literal and must not contain unescaped '%' conversion specifiers.
#define FAISS_ASSERT_MSG(X, MSG)                                            \
    do {                                                                    \
        if (!(X)) {                                                         \
            std::fprintf(stderr,                                            \
                         "Faiss assertion '%s' failed in %s at %s:%d"       \
                         "; details: " MSG "\n",                            \
                         #X, __PRETTY_FUNCTION__, __FILE__, __LINE__);      \
            std::abort();                                                   \
        }                                                                   \
    } while (false)
|
41
|
+
|
42
|
+
/// Same as FAISS_ASSERT, with a printf-style explanation appended to the
/// report. FMT must be a string literal and at least one variadic argument
/// is required (the expansion always places a comma before __VA_ARGS__).
#define FAISS_ASSERT_FMT(X, FMT, ...)                                       \
    do {                                                                    \
        if (!(X)) {                                                         \
            std::fprintf(stderr,                                            \
                         "Faiss assertion '%s' failed in %s at %s:%d"       \
                         "; details: " FMT "\n",                            \
                         #X, __PRETTY_FUNCTION__, __FILE__, __LINE__,       \
                         __VA_ARGS__);                                      \
            std::abort();                                                   \
        }                                                                   \
    } while (false)
|
51
|
+
|
52
|
+
///
|
53
|
+
/// Exceptions for returning user errors
|
54
|
+
///
|
55
|
+
|
56
|
+
/// Throws a faiss::FaissException carrying MSG plus the enclosing function,
/// file and line of the call site.
#define FAISS_THROW_MSG(MSG)                                                \
    do {                                                                    \
        throw faiss::FaissException(                                        \
                MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__);              \
    } while (false)
|
60
|
+
|
61
|
+
/// Throws a faiss::FaissException whose message is built printf-style from
/// FMT and the variadic arguments, plus the enclosing function, file and
/// line of the call site.
///
/// The message is sized with a first snprintf pass and written with a second
/// one, so the variadic arguments are evaluated twice: they must be free of
/// side effects. At least one variadic argument is required.
///
/// The scratch locals avoid identifiers containing a double underscore, which
/// are reserved for the implementation. The buffer is trimmed after
/// formatting so the string does not carry snprintf's terminator as an extra
/// trailing character.
#define FAISS_THROW_FMT(FMT, ...)                                           \
    do {                                                                    \
        std::string faiss_fmt_msg_;                                         \
        const int faiss_fmt_len_ = snprintf(nullptr, 0, FMT, __VA_ARGS__);  \
        if (faiss_fmt_len_ > 0) {                                           \
            /* one extra byte for the terminator written by snprintf */     \
            faiss_fmt_msg_.resize(faiss_fmt_len_ + 1);                      \
            snprintf(&faiss_fmt_msg_[0], faiss_fmt_msg_.size(),             \
                     FMT, __VA_ARGS__);                                     \
            faiss_fmt_msg_.resize(faiss_fmt_len_);                          \
        }                                                                   \
        throw faiss::FaissException(                                        \
                faiss_fmt_msg_, __PRETTY_FUNCTION__, __FILE__, __LINE__);   \
    } while (false)
|
69
|
+
|
70
|
+
///
|
71
|
+
/// Exceptions thrown upon a conditional failure
|
72
|
+
///
|
73
|
+
|
74
|
+
/// Throws (through FAISS_THROW_FMT) when X is false; the exception message
/// quotes the stringified expression. X is evaluated exactly once.
#define FAISS_THROW_IF_NOT(X)                                               \
    do {                                                                    \
        if (!(X)) {                                                         \
            FAISS_THROW_FMT("Error: '%s' failed", #X);                      \
        }                                                                   \
    } while (false)
|
80
|
+
|
81
|
+
/// Throws (through FAISS_THROW_FMT) when X is false; the exception message
/// quotes the stringified expression followed by MSG.
/// MSG is pasted into the format string, so it has to be a string literal
/// and must not contain unescaped '%' conversion specifiers.
#define FAISS_THROW_IF_NOT_MSG(X, MSG)                                      \
    do {                                                                    \
        if (!(X)) {                                                         \
            FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X);                \
        }                                                                   \
    } while (false)
|
87
|
+
|
88
|
+
/// Throws (through FAISS_THROW_FMT) when X is false; the exception message
/// quotes the stringified expression followed by the printf-style FMT
/// expanded with the variadic arguments. FMT must be a string literal and at
/// least one variadic argument is required.
#define FAISS_THROW_IF_NOT_FMT(X, FMT, ...)                                 \
    do {                                                                    \
        if (!(X)) {                                                         \
            FAISS_THROW_FMT("Error: '%s' failed: " FMT, #X, __VA_ARGS__);   \
        }                                                                   \
    } while (false)
|
94
|
+
|
95
|
+
#endif
|
@@ -0,0 +1,66 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/impl/FaissException.h>
|
11
|
+
#include <sstream>
|
12
|
+
|
13
|
+
namespace faiss {
|
14
|
+
|
15
|
+
/// Stores the given text verbatim as the exception message.
FaissException::FaissException(const std::string& m) : msg(m) {}
|
18
|
+
|
19
|
+
/// Builds the message "Error in <funcName> at <file>:<line>: <m>".
///
/// The text is sized with a first snprintf pass and written with a second
/// one. The buffer is then trimmed so that msg does not keep snprintf's
/// terminator as an extra trailing character (msg.size() would otherwise be
/// one too large and any concatenation would embed a NUL).
FaissException::FaissException(const std::string& m,
                               const char* funcName,
                               const char* file,
                               int line) {
    const int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s",
                              funcName, file, line, m.c_str());
    if (size < 0) {
        // formatting failed: keep at least the caller's message
        msg = m;
        return;
    }
    // one extra byte for the terminator written by snprintf
    msg.resize(size + 1);
    snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s",
             funcName, file, line, m.c_str());
    msg.resize(size);
}
|
29
|
+
|
30
|
+
const char*
|
31
|
+
FaissException::what() const noexcept {
|
32
|
+
return msg.c_str();
|
33
|
+
}
|
34
|
+
|
35
|
+
void handleExceptions(
|
36
|
+
std::vector<std::pair<int, std::exception_ptr>>& exceptions) {
|
37
|
+
if (exceptions.size() == 1) {
|
38
|
+
// throw the single received exception directly
|
39
|
+
std::rethrow_exception(exceptions.front().second);
|
40
|
+
|
41
|
+
} else if (exceptions.size() > 1) {
|
42
|
+
// multiple exceptions; aggregate them and return a single exception
|
43
|
+
std::stringstream ss;
|
44
|
+
|
45
|
+
for (auto& p : exceptions) {
|
46
|
+
try {
|
47
|
+
std::rethrow_exception(p.second);
|
48
|
+
} catch (std::exception& ex) {
|
49
|
+
if (ex.what()) {
|
50
|
+
// exception message available
|
51
|
+
ss << "Exception thrown from index " << p.first << ": "
|
52
|
+
<< ex.what() << "\n";
|
53
|
+
} else {
|
54
|
+
// No message available
|
55
|
+
ss << "Unknown exception thrown from index " << p.first << "\n";
|
56
|
+
}
|
57
|
+
} catch (...) {
|
58
|
+
ss << "Unknown exception thrown from index " << p.first << "\n";
|
59
|
+
}
|
60
|
+
}
|
61
|
+
|
62
|
+
throw FaissException(ss.str());
|
63
|
+
}
|
64
|
+
}
|
65
|
+
|
66
|
+
}
|
@@ -0,0 +1,71 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#ifndef FAISS_EXCEPTION_INCLUDED
|
11
|
+
#define FAISS_EXCEPTION_INCLUDED
|
12
|
+
|
13
|
+
#include <exception>
|
14
|
+
#include <string>
|
15
|
+
#include <vector>
|
16
|
+
#include <utility>
|
17
|
+
|
18
|
+
namespace faiss {
|
19
|
+
|
20
|
+
/// Base class for Faiss exceptions
class FaissException : public std::exception {
public:
    /// Exception whose message is `msg`.
    explicit FaissException(const std::string& msg);

    /// Exception whose message combines `msg` with the function name, file
    /// and line it was raised from.
    FaissException(const std::string& msg,
                   const char* funcName,
                   const char* file,
                   int line);

    /// from std::exception
    const char* what() const noexcept override;

    /// Text returned by what().
    std::string msg;
};
|
35
|
+
|
36
|
+
/// Handle multiple exceptions from worker threads, throwing an appropriate
/// exception that aggregates the information.
/// The int of each pair is the thread that generated the exception.
void handleExceptions(
        std::vector<std::pair<int, std::exception_ptr>>& exceptions);
|
41
|
+
|
42
|
+
/** bare-bones unique_ptr
 * this one deletes with delete []
 *
 * The held pointer is freed when the object goes out of scope unless
 * release() was called first. Copying is disabled: a copy would hold the
 * same pointer and both destructors would delete [] it.
 */
template <class T>
struct ScopeDeleter {
    const T* ptr;

    explicit ScopeDeleter(const T* ptr = nullptr) : ptr(ptr) {}

    ScopeDeleter(const ScopeDeleter&) = delete;
    ScopeDeleter& operator=(const ScopeDeleter&) = delete;

    /// give up ownership: the destructor will not free anything
    void release() { ptr = nullptr; }

    /// replace the held pointer; the previous one is NOT freed
    void set(const T* ptr_in) { ptr = ptr_in; }

    /// exchange the held pointers of two deleters
    void swap(ScopeDeleter<T>& other) { std::swap(ptr, other.ptr); }

    ~ScopeDeleter() {
        delete[] ptr;
    }
};
|
55
|
+
|
56
|
+
/** bare-bones unique_ptr that deletes with the simple delete
 * (least common case)
 *
 * The held pointer is freed when the object goes out of scope unless
 * release() was called first. Copying is disabled: a copy would hold the
 * same pointer and both destructors would delete it.
 */
template <class T>
struct ScopeDeleter1 {
    const T* ptr;

    explicit ScopeDeleter1(const T* ptr = nullptr) : ptr(ptr) {}

    ScopeDeleter1(const ScopeDeleter1&) = delete;
    ScopeDeleter1& operator=(const ScopeDeleter1&) = delete;

    /// give up ownership: the destructor will not free anything
    void release() { ptr = nullptr; }

    /// replace the held pointer; the previous one is NOT freed
    void set(const T* ptr_in) { ptr = ptr_in; }

    /// exchange the held pointers of two deleters
    void swap(ScopeDeleter1<T>& other) { std::swap(ptr, other.ptr); }

    ~ScopeDeleter1() {
        delete ptr;
    }
};
|
68
|
+
|
69
|
+
}
|
70
|
+
|
71
|
+
#endif
|
@@ -0,0 +1,818 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/impl/HNSW.h>
|
11
|
+
|
12
|
+
#include <string>
|
13
|
+
|
14
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
15
|
+
|
16
|
+
namespace faiss {
|
17
|
+
|
18
|
+
using idx_t = Index::idx_t;
|
19
|
+
|
20
|
+
/**************************************************************
|
21
|
+
* HNSW structure implementation
|
22
|
+
**************************************************************/
|
23
|
+
|
24
|
+
int HNSW::nb_neighbors(int layer_no) const
|
25
|
+
{
|
26
|
+
return cum_nneighbor_per_level[layer_no + 1] -
|
27
|
+
cum_nneighbor_per_level[layer_no];
|
28
|
+
}
|
29
|
+
|
30
|
+
void HNSW::set_nb_neighbors(int level_no, int n)
|
31
|
+
{
|
32
|
+
FAISS_THROW_IF_NOT(levels.size() == 0);
|
33
|
+
int cur_n = nb_neighbors(level_no);
|
34
|
+
for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
|
35
|
+
cum_nneighbor_per_level[i] += n - cur_n;
|
36
|
+
}
|
37
|
+
}
|
38
|
+
|
39
|
+
int HNSW::cum_nb_neighbors(int layer_no) const
|
40
|
+
{
|
41
|
+
return cum_nneighbor_per_level[layer_no];
|
42
|
+
}
|
43
|
+
|
44
|
+
/// Computes the half-open range [*begin, *end) of indices into `neighbors`
/// that holds the links of vertex `no` at layer `layer_no`: the vertex's base
/// offset plus the cumulative slot counts of that layer and the next one.
void HNSW::neighbor_range(idx_t no, int layer_no,
                          size_t* begin, size_t* end) const {
    const size_t base = offsets[no];
    *begin = base + cum_nneighbor_per_level[layer_no];
    *end = base + cum_nneighbor_per_level[layer_no + 1];
}
|
51
|
+
|
52
|
+
|
53
|
+
|
54
|
+
/// Creates an empty structure: no entry point, no levels, a single 0 offset,
/// the random generator seeded with 12345 and the default search/construction
/// parameters. The per-level tables are derived from M with a level
/// multiplier of 1 / log(M).
HNSW::HNSW(int M) : rng(12345) {
    // empty graph
    max_level = -1;
    entry_point = -1;
    offsets.push_back(0);

    // default parameters
    efSearch = 16;
    efConstruction = 40;
    upper_beam = 1;

    set_default_probas(M, 1.0 / log(M));
}
|
63
|
+
|
64
|
+
|
65
|
+
int HNSW::random_level()
|
66
|
+
{
|
67
|
+
double f = rng.rand_float();
|
68
|
+
// could be a bit faster with bissection
|
69
|
+
for (int level = 0; level < assign_probas.size(); level++) {
|
70
|
+
if (f < assign_probas[level]) {
|
71
|
+
return level;
|
72
|
+
}
|
73
|
+
f -= assign_probas[level];
|
74
|
+
}
|
75
|
+
// happens with exponentially low probability
|
76
|
+
return assign_probas.size() - 1;
|
77
|
+
}
|
78
|
+
|
79
|
+
void HNSW::set_default_probas(int M, float levelMult)
|
80
|
+
{
|
81
|
+
int nn = 0;
|
82
|
+
cum_nneighbor_per_level.push_back (0);
|
83
|
+
for (int level = 0; ;level++) {
|
84
|
+
float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
|
85
|
+
if (proba < 1e-9) break;
|
86
|
+
assign_probas.push_back(proba);
|
87
|
+
nn += level == 0 ? M * 2 : M;
|
88
|
+
cum_nneighbor_per_level.push_back (nn);
|
89
|
+
}
|
90
|
+
}
|
91
|
+
|
92
|
+
void HNSW::clear_neighbor_tables(int level)
|
93
|
+
{
|
94
|
+
for (int i = 0; i < levels.size(); i++) {
|
95
|
+
size_t begin, end;
|
96
|
+
neighbor_range(i, level, &begin, &end);
|
97
|
+
for (size_t j = begin; j < end; j++) {
|
98
|
+
neighbors[j] = -1;
|
99
|
+
}
|
100
|
+
}
|
101
|
+
}
|
102
|
+
|
103
|
+
|
104
|
+
/// Drops all vertices and links: clears `levels` and `neighbors`, restores
/// `offsets` to a single 0 and unsets the entry point and maximum level.
void HNSW::reset() {
    levels.clear();
    neighbors.clear();

    offsets.clear();
    offsets.push_back(0);

    entry_point = -1;
    max_level = -1;
}
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
void HNSW::print_neighbor_stats(int level) const
|
116
|
+
{
|
117
|
+
FAISS_THROW_IF_NOT (level < cum_nneighbor_per_level.size());
|
118
|
+
printf("stats on level %d, max %d neighbors per vertex:\n",
|
119
|
+
level, nb_neighbors(level));
|
120
|
+
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
|
121
|
+
#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
|
122
|
+
reduction(+: tot_reciprocal) reduction(+: n_node)
|
123
|
+
for (int i = 0; i < levels.size(); i++) {
|
124
|
+
if (levels[i] > level) {
|
125
|
+
n_node++;
|
126
|
+
size_t begin, end;
|
127
|
+
neighbor_range(i, level, &begin, &end);
|
128
|
+
std::unordered_set<int> neighset;
|
129
|
+
for (size_t j = begin; j < end; j++) {
|
130
|
+
if (neighbors [j] < 0) break;
|
131
|
+
neighset.insert(neighbors[j]);
|
132
|
+
}
|
133
|
+
int n_neigh = neighset.size();
|
134
|
+
int n_common = 0;
|
135
|
+
int n_reciprocal = 0;
|
136
|
+
for (size_t j = begin; j < end; j++) {
|
137
|
+
storage_idx_t i2 = neighbors[j];
|
138
|
+
if (i2 < 0) break;
|
139
|
+
FAISS_ASSERT(i2 != i);
|
140
|
+
size_t begin2, end2;
|
141
|
+
neighbor_range(i2, level, &begin2, &end2);
|
142
|
+
for (size_t j2 = begin2; j2 < end2; j2++) {
|
143
|
+
storage_idx_t i3 = neighbors[j2];
|
144
|
+
if (i3 < 0) break;
|
145
|
+
if (i3 == i) {
|
146
|
+
n_reciprocal++;
|
147
|
+
continue;
|
148
|
+
}
|
149
|
+
if (neighset.count(i3)) {
|
150
|
+
neighset.erase(i3);
|
151
|
+
n_common++;
|
152
|
+
}
|
153
|
+
}
|
154
|
+
}
|
155
|
+
tot_neigh += n_neigh;
|
156
|
+
tot_common += n_common;
|
157
|
+
tot_reciprocal += n_reciprocal;
|
158
|
+
}
|
159
|
+
}
|
160
|
+
float normalizer = n_node;
|
161
|
+
printf(" nb of nodes at that level %ld\n", n_node);
|
162
|
+
printf(" neighbors per node: %.2f (%ld)\n",
|
163
|
+
tot_neigh / normalizer, tot_neigh);
|
164
|
+
printf(" nb of reciprocal neighbors: %.2f\n", tot_reciprocal / normalizer);
|
165
|
+
printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%ld)\n",
|
166
|
+
tot_common / normalizer, tot_common);
|
167
|
+
|
168
|
+
|
169
|
+
|
170
|
+
}
|
171
|
+
|
172
|
+
|
173
|
+
void HNSW::fill_with_random_links(size_t n)
|
174
|
+
{
|
175
|
+
int max_level = prepare_level_tab(n);
|
176
|
+
RandomGenerator rng2(456);
|
177
|
+
|
178
|
+
for (int level = max_level - 1; level >= 0; --level) {
|
179
|
+
std::vector<int> elts;
|
180
|
+
for (int i = 0; i < n; i++) {
|
181
|
+
if (levels[i] > level) {
|
182
|
+
elts.push_back(i);
|
183
|
+
}
|
184
|
+
}
|
185
|
+
printf ("linking %ld elements in level %d\n",
|
186
|
+
elts.size(), level);
|
187
|
+
|
188
|
+
if (elts.size() == 1) continue;
|
189
|
+
|
190
|
+
for (int ii = 0; ii < elts.size(); ii++) {
|
191
|
+
int i = elts[ii];
|
192
|
+
size_t begin, end;
|
193
|
+
neighbor_range(i, 0, &begin, &end);
|
194
|
+
for (size_t j = begin; j < end; j++) {
|
195
|
+
int other = 0;
|
196
|
+
do {
|
197
|
+
other = elts[rng2.rand_int(elts.size())];
|
198
|
+
} while(other == i);
|
199
|
+
|
200
|
+
neighbors[j] = other;
|
201
|
+
}
|
202
|
+
}
|
203
|
+
}
|
204
|
+
}
|
205
|
+
|
206
|
+
|
207
|
+
int HNSW::prepare_level_tab(size_t n, bool preset_levels)
|
208
|
+
{
|
209
|
+
size_t n0 = offsets.size() - 1;
|
210
|
+
|
211
|
+
if (preset_levels) {
|
212
|
+
FAISS_ASSERT (n0 + n == levels.size());
|
213
|
+
} else {
|
214
|
+
FAISS_ASSERT (n0 == levels.size());
|
215
|
+
for (int i = 0; i < n; i++) {
|
216
|
+
int pt_level = random_level();
|
217
|
+
levels.push_back(pt_level + 1);
|
218
|
+
}
|
219
|
+
}
|
220
|
+
|
221
|
+
int max_level = 0;
|
222
|
+
for (int i = 0; i < n; i++) {
|
223
|
+
int pt_level = levels[i + n0] - 1;
|
224
|
+
if (pt_level > max_level) max_level = pt_level;
|
225
|
+
offsets.push_back(offsets.back() +
|
226
|
+
cum_nb_neighbors(pt_level + 1));
|
227
|
+
neighbors.resize(offsets.back(), -1);
|
228
|
+
}
|
229
|
+
|
230
|
+
return max_level;
|
231
|
+
}
|
232
|
+
|
233
|
+
|
234
|
+
/** Enumerate vertices from farthest to nearest from query, keep a
|
235
|
+
* neighbor only if there is no previous neighbor that is closer to
|
236
|
+
* that vertex than the query.
|
237
|
+
*/
|
238
|
+
void HNSW::shrink_neighbor_list(
|
239
|
+
DistanceComputer& qdis,
|
240
|
+
std::priority_queue<NodeDistFarther>& input,
|
241
|
+
std::vector<NodeDistFarther>& output,
|
242
|
+
int max_size)
|
243
|
+
{
|
244
|
+
while (input.size() > 0) {
|
245
|
+
NodeDistFarther v1 = input.top();
|
246
|
+
input.pop();
|
247
|
+
float dist_v1_q = v1.d;
|
248
|
+
|
249
|
+
bool good = true;
|
250
|
+
for (NodeDistFarther v2 : output) {
|
251
|
+
float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
|
252
|
+
|
253
|
+
if (dist_v1_v2 < dist_v1_q) {
|
254
|
+
good = false;
|
255
|
+
break;
|
256
|
+
}
|
257
|
+
}
|
258
|
+
|
259
|
+
if (good) {
|
260
|
+
output.push_back(v1);
|
261
|
+
if (output.size() >= max_size) {
|
262
|
+
return;
|
263
|
+
}
|
264
|
+
}
|
265
|
+
}
|
266
|
+
}
|
267
|
+
|
268
|
+
|
269
|
+
namespace {
|
270
|
+
|
271
|
+
|
272
|
+
using storage_idx_t = HNSW::storage_idx_t;
|
273
|
+
using NodeDistCloser = HNSW::NodeDistCloser;
|
274
|
+
using NodeDistFarther = HNSW::NodeDistFarther;
|
275
|
+
|
276
|
+
|
277
|
+
/**************************************************************
|
278
|
+
* Addition subroutines
|
279
|
+
**************************************************************/
|
280
|
+
|
281
|
+
|
282
|
+
/// remove neighbors from the list to make it smaller than max_size
|
283
|
+
void shrink_neighbor_list(
|
284
|
+
DistanceComputer& qdis,
|
285
|
+
std::priority_queue<NodeDistCloser>& resultSet1,
|
286
|
+
int max_size)
|
287
|
+
{
|
288
|
+
if (resultSet1.size() < max_size) {
|
289
|
+
return;
|
290
|
+
}
|
291
|
+
std::priority_queue<NodeDistFarther> resultSet;
|
292
|
+
std::vector<NodeDistFarther> returnlist;
|
293
|
+
|
294
|
+
while (resultSet1.size() > 0) {
|
295
|
+
resultSet.emplace(resultSet1.top().d, resultSet1.top().id);
|
296
|
+
resultSet1.pop();
|
297
|
+
}
|
298
|
+
|
299
|
+
HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
|
300
|
+
|
301
|
+
for (NodeDistFarther curen2 : returnlist) {
|
302
|
+
resultSet1.emplace(curen2.d, curen2.id);
|
303
|
+
}
|
304
|
+
|
305
|
+
}
|
306
|
+
|
307
|
+
|
308
|
+
/// add a link between two elements, possibly shrinking the list
|
309
|
+
/// of links to make room for it.
|
310
|
+
void add_link(HNSW& hnsw,
|
311
|
+
DistanceComputer& qdis,
|
312
|
+
storage_idx_t src, storage_idx_t dest,
|
313
|
+
int level)
|
314
|
+
{
|
315
|
+
size_t begin, end;
|
316
|
+
hnsw.neighbor_range(src, level, &begin, &end);
|
317
|
+
if (hnsw.neighbors[end - 1] == -1) {
|
318
|
+
// there is enough room, find a slot to add it
|
319
|
+
size_t i = end;
|
320
|
+
while(i > begin) {
|
321
|
+
if (hnsw.neighbors[i - 1] != -1) break;
|
322
|
+
i--;
|
323
|
+
}
|
324
|
+
hnsw.neighbors[i] = dest;
|
325
|
+
return;
|
326
|
+
}
|
327
|
+
|
328
|
+
// otherwise we let them fight out which to keep
|
329
|
+
|
330
|
+
// copy to resultSet...
|
331
|
+
std::priority_queue<NodeDistCloser> resultSet;
|
332
|
+
resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
|
333
|
+
for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
|
334
|
+
storage_idx_t neigh = hnsw.neighbors[i];
|
335
|
+
resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
|
336
|
+
}
|
337
|
+
|
338
|
+
shrink_neighbor_list(qdis, resultSet, end - begin);
|
339
|
+
|
340
|
+
// ...and back
|
341
|
+
size_t i = begin;
|
342
|
+
while (resultSet.size()) {
|
343
|
+
hnsw.neighbors[i++] = resultSet.top().id;
|
344
|
+
resultSet.pop();
|
345
|
+
}
|
346
|
+
// they may have shrunk more than just by 1 element
|
347
|
+
while(i < end) {
|
348
|
+
hnsw.neighbors[i++] = -1;
|
349
|
+
}
|
350
|
+
}
|
351
|
+
|
352
|
+
/// search neighbors on a single level, starting from an entry point
|
353
|
+
void search_neighbors_to_add(
|
354
|
+
HNSW& hnsw,
|
355
|
+
DistanceComputer& qdis,
|
356
|
+
std::priority_queue<NodeDistCloser>& results,
|
357
|
+
int entry_point,
|
358
|
+
float d_entry_point,
|
359
|
+
int level,
|
360
|
+
VisitedTable &vt)
|
361
|
+
{
|
362
|
+
// top is nearest candidate
|
363
|
+
std::priority_queue<NodeDistFarther> candidates;
|
364
|
+
|
365
|
+
NodeDistFarther ev(d_entry_point, entry_point);
|
366
|
+
candidates.push(ev);
|
367
|
+
results.emplace(d_entry_point, entry_point);
|
368
|
+
vt.set(entry_point);
|
369
|
+
|
370
|
+
while (!candidates.empty()) {
|
371
|
+
// get nearest
|
372
|
+
const NodeDistFarther &currEv = candidates.top();
|
373
|
+
|
374
|
+
if (currEv.d > results.top().d) {
|
375
|
+
break;
|
376
|
+
}
|
377
|
+
int currNode = currEv.id;
|
378
|
+
candidates.pop();
|
379
|
+
|
380
|
+
// loop over neighbors
|
381
|
+
size_t begin, end;
|
382
|
+
hnsw.neighbor_range(currNode, level, &begin, &end);
|
383
|
+
for(size_t i = begin; i < end; i++) {
|
384
|
+
storage_idx_t nodeId = hnsw.neighbors[i];
|
385
|
+
if (nodeId < 0) break;
|
386
|
+
if (vt.get(nodeId)) continue;
|
387
|
+
vt.set(nodeId);
|
388
|
+
|
389
|
+
float dis = qdis(nodeId);
|
390
|
+
NodeDistFarther evE1(dis, nodeId);
|
391
|
+
|
392
|
+
if (results.size() < hnsw.efConstruction ||
|
393
|
+
results.top().d > dis) {
|
394
|
+
|
395
|
+
results.emplace(dis, nodeId);
|
396
|
+
candidates.emplace(dis, nodeId);
|
397
|
+
if (results.size() > hnsw.efConstruction) {
|
398
|
+
results.pop();
|
399
|
+
}
|
400
|
+
}
|
401
|
+
}
|
402
|
+
}
|
403
|
+
vt.advance();
|
404
|
+
}
|
405
|
+
|
406
|
+
|
407
|
+
/**************************************************************
|
408
|
+
* Searching subroutines
|
409
|
+
**************************************************************/
|
410
|
+
|
411
|
+
/// greedily update a nearest vector at a given level
|
412
|
+
void greedy_update_nearest(const HNSW& hnsw,
|
413
|
+
DistanceComputer& qdis,
|
414
|
+
int level,
|
415
|
+
storage_idx_t& nearest,
|
416
|
+
float& d_nearest)
|
417
|
+
{
|
418
|
+
for(;;) {
|
419
|
+
storage_idx_t prev_nearest = nearest;
|
420
|
+
|
421
|
+
size_t begin, end;
|
422
|
+
hnsw.neighbor_range(nearest, level, &begin, &end);
|
423
|
+
for(size_t i = begin; i < end; i++) {
|
424
|
+
storage_idx_t v = hnsw.neighbors[i];
|
425
|
+
if (v < 0) break;
|
426
|
+
float dis = qdis(v);
|
427
|
+
if (dis < d_nearest) {
|
428
|
+
nearest = v;
|
429
|
+
d_nearest = dis;
|
430
|
+
}
|
431
|
+
}
|
432
|
+
if (nearest == prev_nearest) {
|
433
|
+
return;
|
434
|
+
}
|
435
|
+
}
|
436
|
+
}
|
437
|
+
|
438
|
+
|
439
|
+
} // namespace
|
440
|
+
|
441
|
+
|
442
|
+
/// Finds neighbors and builds links with them, starting from an entry
|
443
|
+
/// point. The own neighbor list is assumed to be locked.
|
444
|
+
void HNSW::add_links_starting_from(DistanceComputer& ptdis,
|
445
|
+
storage_idx_t pt_id,
|
446
|
+
storage_idx_t nearest,
|
447
|
+
float d_nearest,
|
448
|
+
int level,
|
449
|
+
omp_lock_t *locks,
|
450
|
+
VisitedTable &vt)
|
451
|
+
{
|
452
|
+
std::priority_queue<NodeDistCloser> link_targets;
|
453
|
+
|
454
|
+
search_neighbors_to_add(*this, ptdis, link_targets, nearest, d_nearest,
|
455
|
+
level, vt);
|
456
|
+
|
457
|
+
// but we can afford only this many neighbors
|
458
|
+
int M = nb_neighbors(level);
|
459
|
+
|
460
|
+
::faiss::shrink_neighbor_list(ptdis, link_targets, M);
|
461
|
+
|
462
|
+
while (!link_targets.empty()) {
|
463
|
+
int other_id = link_targets.top().id;
|
464
|
+
|
465
|
+
omp_set_lock(&locks[other_id]);
|
466
|
+
add_link(*this, ptdis, other_id, pt_id, level);
|
467
|
+
omp_unset_lock(&locks[other_id]);
|
468
|
+
|
469
|
+
add_link(*this, ptdis, pt_id, other_id, level);
|
470
|
+
|
471
|
+
link_targets.pop();
|
472
|
+
}
|
473
|
+
}
|
474
|
+
|
475
|
+
|
476
|
+
/**************************************************************
|
477
|
+
* Building, parallel
|
478
|
+
**************************************************************/
|
479
|
+
|
480
|
+
/// Inserts point `pt_id`, whose top level is `pt_level`, into the graph.
///
/// The first point ever inserted only becomes the entry point. Any other
/// point descends greedily from `max_level` down to `pt_level + 1`, then
/// gets linked at every level from there down to 0 while holding its own
/// lock.
///
/// @param ptdis     distance computer for the point being inserted
/// @param pt_level  highest level of the new point
/// @param pt_id     id of the new point
/// @param locks     one lock per node, indexed by node id
/// @param vt        visited table forwarded to the linking step
void HNSW::add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
                          std::vector<omp_lock_t>& locks,
                          VisitedTable& vt)
{
    storage_idx_t closest;

    // read the entry point and claim it if none exists yet
#pragma omp critical
    {
        closest = entry_point;

        if (closest == -1) {
            max_level = pt_level;
            entry_point = pt_id;
        }
    }

    // nothing to link against when the graph was empty
    if (closest < 0) {
        return;
    }

    omp_lock_t *own_lock = &locks[pt_id];
    omp_set_lock(own_lock);

    int cur_level = max_level; // level at which we start adding neighbors
    float closest_dis = ptdis(closest);

    // greedy search on the levels above the point's own top level
    while (cur_level > pt_level) {
        greedy_update_nearest(*this, ptdis, cur_level, closest, closest_dis);
        --cur_level;
    }

    // create links on every level the point belongs to
    while (cur_level >= 0) {
        add_links_starting_from(ptdis, pt_id, closest, closest_dis,
                                cur_level, locks.data(), vt);
        --cur_level;
    }

    omp_unset_lock(own_lock);

    // a point that reaches above the current top becomes the new entry point
    if (pt_level > max_level) {
        max_level = pt_level;
        entry_point = pt_id;
    }
}
|
522
|
+
|
523
|
+
|
524
|
+
/** Do a BFS on the candidates list */
|
525
|
+
|
526
|
+
int HNSW::search_from_candidates(
|
527
|
+
DistanceComputer& qdis, int k,
|
528
|
+
idx_t *I, float *D,
|
529
|
+
MinimaxHeap& candidates,
|
530
|
+
VisitedTable& vt,
|
531
|
+
int level, int nres_in) const
|
532
|
+
{
|
533
|
+
int nres = nres_in;
|
534
|
+
int ndis = 0;
|
535
|
+
for (int i = 0; i < candidates.size(); i++) {
|
536
|
+
idx_t v1 = candidates.ids[i];
|
537
|
+
float d = candidates.dis[i];
|
538
|
+
FAISS_ASSERT(v1 >= 0);
|
539
|
+
if (nres < k) {
|
540
|
+
faiss::maxheap_push(++nres, D, I, d, v1);
|
541
|
+
} else if (d < D[0]) {
|
542
|
+
faiss::maxheap_pop(nres--, D, I);
|
543
|
+
faiss::maxheap_push(++nres, D, I, d, v1);
|
544
|
+
}
|
545
|
+
vt.set(v1);
|
546
|
+
}
|
547
|
+
|
548
|
+
bool do_dis_check = check_relative_distance;
|
549
|
+
int nstep = 0;
|
550
|
+
|
551
|
+
while (candidates.size() > 0) {
|
552
|
+
float d0 = 0;
|
553
|
+
int v0 = candidates.pop_min(&d0);
|
554
|
+
|
555
|
+
if (do_dis_check) {
|
556
|
+
// tricky stopping condition: there are more that ef
|
557
|
+
// distances that are processed already that are smaller
|
558
|
+
// than d0
|
559
|
+
|
560
|
+
int n_dis_below = candidates.count_below(d0);
|
561
|
+
if(n_dis_below >= efSearch) {
|
562
|
+
break;
|
563
|
+
}
|
564
|
+
}
|
565
|
+
|
566
|
+
size_t begin, end;
|
567
|
+
neighbor_range(v0, level, &begin, &end);
|
568
|
+
|
569
|
+
for (size_t j = begin; j < end; j++) {
|
570
|
+
int v1 = neighbors[j];
|
571
|
+
if (v1 < 0) break;
|
572
|
+
if (vt.get(v1)) {
|
573
|
+
continue;
|
574
|
+
}
|
575
|
+
vt.set(v1);
|
576
|
+
ndis++;
|
577
|
+
float d = qdis(v1);
|
578
|
+
if (nres < k) {
|
579
|
+
faiss::maxheap_push(++nres, D, I, d, v1);
|
580
|
+
} else if (d < D[0]) {
|
581
|
+
faiss::maxheap_pop(nres--, D, I);
|
582
|
+
faiss::maxheap_push(++nres, D, I, d, v1);
|
583
|
+
}
|
584
|
+
candidates.push(v1, d);
|
585
|
+
}
|
586
|
+
|
587
|
+
nstep++;
|
588
|
+
if (!do_dis_check && nstep > efSearch) {
|
589
|
+
break;
|
590
|
+
}
|
591
|
+
}
|
592
|
+
|
593
|
+
if (level == 0) {
|
594
|
+
#pragma omp critical
|
595
|
+
{
|
596
|
+
hnsw_stats.n1 ++;
|
597
|
+
if (candidates.size() == 0) {
|
598
|
+
hnsw_stats.n2 ++;
|
599
|
+
}
|
600
|
+
hnsw_stats.n3 += ndis;
|
601
|
+
}
|
602
|
+
}
|
603
|
+
|
604
|
+
return nres;
|
605
|
+
}
|
606
|
+
|
607
|
+
|
608
|
+
/**************************************************************
|
609
|
+
* Searching
|
610
|
+
**************************************************************/
|
611
|
+
|
612
|
+
std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
|
613
|
+
const Node& node,
|
614
|
+
DistanceComputer& qdis,
|
615
|
+
int ef,
|
616
|
+
VisitedTable *vt) const
|
617
|
+
{
|
618
|
+
int ndis = 0;
|
619
|
+
std::priority_queue<Node> top_candidates;
|
620
|
+
std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
|
621
|
+
|
622
|
+
top_candidates.push(node);
|
623
|
+
candidates.push(node);
|
624
|
+
|
625
|
+
vt->set(node.second);
|
626
|
+
|
627
|
+
while (!candidates.empty()) {
|
628
|
+
float d0;
|
629
|
+
storage_idx_t v0;
|
630
|
+
std::tie(d0, v0) = candidates.top();
|
631
|
+
|
632
|
+
if (d0 > top_candidates.top().first) {
|
633
|
+
break;
|
634
|
+
}
|
635
|
+
|
636
|
+
candidates.pop();
|
637
|
+
|
638
|
+
size_t begin, end;
|
639
|
+
neighbor_range(v0, 0, &begin, &end);
|
640
|
+
|
641
|
+
for (size_t j = begin; j < end; ++j) {
|
642
|
+
int v1 = neighbors[j];
|
643
|
+
|
644
|
+
if (v1 < 0) {
|
645
|
+
break;
|
646
|
+
}
|
647
|
+
if (vt->get(v1)) {
|
648
|
+
continue;
|
649
|
+
}
|
650
|
+
|
651
|
+
vt->set(v1);
|
652
|
+
|
653
|
+
float d1 = qdis(v1);
|
654
|
+
++ndis;
|
655
|
+
|
656
|
+
if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
|
657
|
+
candidates.emplace(d1, v1);
|
658
|
+
top_candidates.emplace(d1, v1);
|
659
|
+
|
660
|
+
if (top_candidates.size() > ef) {
|
661
|
+
top_candidates.pop();
|
662
|
+
}
|
663
|
+
}
|
664
|
+
}
|
665
|
+
}
|
666
|
+
|
667
|
+
#pragma omp critical
|
668
|
+
{
|
669
|
+
++hnsw_stats.n1;
|
670
|
+
if (candidates.size() == 0) {
|
671
|
+
++hnsw_stats.n2;
|
672
|
+
}
|
673
|
+
hnsw_stats.n3 += ndis;
|
674
|
+
}
|
675
|
+
|
676
|
+
return top_candidates;
|
677
|
+
}
|
678
|
+
|
679
|
+
/// Searches the graph for the nearest neighbors of the query behind `qdis`.
///
/// With `upper_beam == 1` the upper levels are traversed greedily and level
/// 0 is searched with either the bounded MinimaxHeap or the unbounded
/// priority-queue variant, depending on `search_bounded_queue`. Otherwise a
/// beam of `upper_beam` candidates is carried from level to level.
///
/// If the graph has no entry point the function returns immediately and
/// leaves `I` / `D` exactly as the caller passed them.
///
/// @param qdis  distance computer, called with a node id
/// @param k     number of results requested
/// @param I     result ids, filled with maxheap_push
/// @param D     result distances, same layout as I
/// @param vt    visited table; advanced after each level-wise search
void HNSW::search(DistanceComputer& qdis, int k,
                  idx_t *I, float *D,
                  VisitedTable& vt) const
{
    // An empty graph has no entry point (add_with_locks treats -1 as "not
    // set yet"). Calling qdis() on a negative id would read an invalid
    // vector, so bail out before touching anything.
    if (entry_point < 0) {
        return;
    }

    if (upper_beam == 1) {

        // greedy search on upper levels
        storage_idx_t nearest = entry_point;
        float d_nearest = qdis(nearest);

        for (int level = max_level; level >= 1; level--) {
            greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
        }

        int ef = std::max(efSearch, k);
        if (search_bounded_queue) {
            MinimaxHeap candidates(ef);

            candidates.push(nearest, d_nearest);

            search_from_candidates(qdis, k, I, D, candidates, vt, 0);
        } else {
            std::priority_queue<Node> top_candidates =
                search_from_candidate_unbounded(Node(d_nearest, nearest),
                                                qdis, ef, &vt);

            // drop the farthest entries until at most k remain
            while (top_candidates.size() > k) {
                top_candidates.pop();
            }

            int nres = 0;
            while (!top_candidates.empty()) {
                float d;
                storage_idx_t label;
                std::tie(d, label) = top_candidates.top();
                faiss::maxheap_push(++nres, D, I, d, label);
                top_candidates.pop();
            }
        }

        vt.advance();

    } else {
        int candidates_size = upper_beam;
        MinimaxHeap candidates(candidates_size);

        std::vector<idx_t> I_to_next(candidates_size);
        std::vector<float> D_to_next(candidates_size);

        int nres = 1;
        I_to_next[0] = entry_point;
        D_to_next[0] = qdis(entry_point);

        for (int level = max_level; level >= 0; level--) {

            // copy I, D -> candidates

            candidates.clear();

            for (int i = 0; i < nres; i++) {
                candidates.push(I_to_next[i], D_to_next[i]);
            }

            if (level == 0) {
                // the final level writes into the caller's result arrays
                nres = search_from_candidates(qdis, k, I, D, candidates, vt, 0);
            } else {
                // upper levels only produce the beam for the next level
                nres = search_from_candidates(
                    qdis, candidates_size,
                    I_to_next.data(), D_to_next.data(),
                    candidates, vt, level
                );
            }
            vt.advance();
        }
    }
}
|
755
|
+
|
756
|
+
|
757
|
+
/// Inserts candidate `i` with distance `v`.
///
/// When the heap is full (`k == n`) the new entry is dropped if `v` is not
/// smaller than the root distance `dis[0]`; otherwise the root is evicted
/// to make room.
///
/// @param i  node id of the candidate
/// @param v  distance of the candidate
void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
    if (k == n) {
        if (v >= dis[0]) return;
        // pop_min() removes an entry by setting its id to -1 and already
        // decrements nvalid for it, but leaves the slot in the heap. Only
        // count the eviction when the root is still a live entry, otherwise
        // nvalid (and therefore size()) would be decremented twice.
        if (ids[0] != -1) {
            --nvalid;
        }
        faiss::heap_pop<HC> (k--, dis.data(), ids.data());
    }
    faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
    ++nvalid;
}
|
766
|
+
|
767
|
+
float HNSW::MinimaxHeap::max() const {
|
768
|
+
return dis[0];
|
769
|
+
}
|
770
|
+
|
771
|
+
int HNSW::MinimaxHeap::size() const {
|
772
|
+
return nvalid;
|
773
|
+
}
|
774
|
+
|
775
|
+
/// Empties the heap by resetting both counters; the `dis` / `ids` storage
/// itself is left untouched.
void HNSW::MinimaxHeap::clear() {
    k = 0;
    nvalid = 0;
}
|
778
|
+
|
779
|
+
int HNSW::MinimaxHeap::pop_min(float *vmin_out) {
|
780
|
+
assert(k > 0);
|
781
|
+
// returns min. This is an O(n) operation
|
782
|
+
int i = k - 1;
|
783
|
+
while (i >= 0) {
|
784
|
+
if (ids[i] != -1) break;
|
785
|
+
i--;
|
786
|
+
}
|
787
|
+
if (i == -1) return -1;
|
788
|
+
int imin = i;
|
789
|
+
float vmin = dis[i];
|
790
|
+
i--;
|
791
|
+
while(i >= 0) {
|
792
|
+
if (ids[i] != -1 && dis[i] < vmin) {
|
793
|
+
vmin = dis[i];
|
794
|
+
imin = i;
|
795
|
+
}
|
796
|
+
i--;
|
797
|
+
}
|
798
|
+
if (vmin_out) *vmin_out = vmin;
|
799
|
+
int ret = ids[imin];
|
800
|
+
ids[imin] = -1;
|
801
|
+
--nvalid;
|
802
|
+
|
803
|
+
return ret;
|
804
|
+
}
|
805
|
+
|
806
|
+
int HNSW::MinimaxHeap::count_below(float thresh) {
|
807
|
+
int n_below = 0;
|
808
|
+
for(int i = 0; i < k; i++) {
|
809
|
+
if (dis[i] < thresh) {
|
810
|
+
n_below++;
|
811
|
+
}
|
812
|
+
}
|
813
|
+
|
814
|
+
return n_below;
|
815
|
+
}
|
816
|
+
|
817
|
+
|
818
|
+
} // namespace faiss
|