faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
#include <faiss/utils/WorkerThread.h>
|
|
10
|
+
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <exception>
|
|
12
|
+
|
|
13
|
+
namespace faiss {
|
|
14
|
+
|
|
15
|
+
namespace {
|
|
16
|
+
|
|
17
|
+
// Invokes `task` and reports the outcome through `result`: the promise
// is fulfilled with `true` when the callable returns normally; anything
// thrown inside the try block is captured and stored in the promise
// instead of propagating to the caller.
void runCallback(std::function<void()>& task,
                 std::promise<bool>& result) {
  try {
    task();

    // Only reached when task() did not throw
    result.set_value(true);
  } catch (...) {
    // Hand the in-flight exception over to whoever holds the future
    result.set_exception(std::current_exception());
  }
}
|
|
27
|
+
|
|
28
|
+
} // namespace
|
|
29
|
+
|
|
30
|
+
/// Starts the worker thread and returns only once it has processed a
/// first (empty) task.
WorkerThread::WorkerThread()
    : wantStop_(false) {
  startThread();

  // Make sure that the thread has started before continuing: queue a
  // no-op and block on its future until the worker has run it
  std::future<bool> started = add([](){});
  started.get();
}
|
|
37
|
+
|
|
38
|
+
/// Requests a stop, then blocks until the worker thread has exited.
WorkerThread::~WorkerThread() {
  // Order matters: the stop flag must be raised before we wait,
  // otherwise the worker keeps sleeping on its condition variable
  this->stop();
  this->waitForThreadExit();
}
|
|
42
|
+
|
|
43
|
+
void
|
|
44
|
+
WorkerThread::startThread() {
|
|
45
|
+
thread_ = std::thread([this](){ threadMain(); });
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
void
|
|
49
|
+
WorkerThread::stop() {
|
|
50
|
+
std::lock_guard<std::mutex> guard(mutex_);
|
|
51
|
+
|
|
52
|
+
wantStop_ = true;
|
|
53
|
+
monitor_.notify_one();
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
std::future<bool>
|
|
57
|
+
WorkerThread::add(std::function<void()> f) {
|
|
58
|
+
std::lock_guard<std::mutex> guard(mutex_);
|
|
59
|
+
|
|
60
|
+
if (wantStop_) {
|
|
61
|
+
// The timer thread has been stopped, or we want to stop; we can't
|
|
62
|
+
// schedule anything else
|
|
63
|
+
std::promise<bool> p;
|
|
64
|
+
auto fut = p.get_future();
|
|
65
|
+
|
|
66
|
+
// did not execute
|
|
67
|
+
p.set_value(false);
|
|
68
|
+
return fut;
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
auto pr = std::promise<bool>();
|
|
72
|
+
auto fut = pr.get_future();
|
|
73
|
+
|
|
74
|
+
queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
|
|
75
|
+
|
|
76
|
+
// Wake up our thread
|
|
77
|
+
monitor_.notify_one();
|
|
78
|
+
return fut;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
void
|
|
82
|
+
WorkerThread::threadMain() {
|
|
83
|
+
threadLoop();
|
|
84
|
+
|
|
85
|
+
// Call all pending tasks
|
|
86
|
+
FAISS_ASSERT(wantStop_);
|
|
87
|
+
|
|
88
|
+
// flush all pending operations
|
|
89
|
+
for (auto& f : queue_) {
|
|
90
|
+
runCallback(f.first, f.second);
|
|
91
|
+
}
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
void
|
|
95
|
+
WorkerThread::threadLoop() {
|
|
96
|
+
while (true) {
|
|
97
|
+
std::pair<std::function<void()>, std::promise<bool>> data;
|
|
98
|
+
|
|
99
|
+
{
|
|
100
|
+
std::unique_lock<std::mutex> lock(mutex_);
|
|
101
|
+
|
|
102
|
+
while (!wantStop_ && queue_.empty()) {
|
|
103
|
+
monitor_.wait(lock);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
if (wantStop_) {
|
|
107
|
+
return;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
data = std::move(queue_.front());
|
|
111
|
+
queue_.pop_front();
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
runCallback(data.first, data.second);
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
void
|
|
119
|
+
WorkerThread::waitForThreadExit() {
|
|
120
|
+
try {
|
|
121
|
+
thread_.join();
|
|
122
|
+
} catch (...) {
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
} // namespace
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
#include <condition_variable>
|
|
12
|
+
#include <future>
|
|
13
|
+
#include <deque>
|
|
14
|
+
#include <thread>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
|
|
18
|
+
/// Owns a single background thread and a queue of lambdas to run on it.
class WorkerThread {
 public:
  /// Starts the worker thread
  WorkerThread();

  /// Stops and waits for the worker thread to exit, flushing all
  /// pending lambdas
  ~WorkerThread();

  /// Request that the worker thread stop itself
  void stop();

  /// Blocks the calling thread until the worker thread has stopped
  void waitForThreadExit();

  /// Adds a lambda to run on the worker thread; returns a future that
  /// can be used to block on its completion.
  /// Future status is `true` if the lambda was run in the worker
  /// thread; `false` if it was not run, because the worker thread is
  /// exiting or has exited.
  std::future<bool> add(std::function<void()> f);

 private:
  /// A queued lambda together with the promise that reports its outcome
  using Task = std::pair<std::function<void()>, std::promise<bool>>;

  void startThread();
  void threadMain();
  void threadLoop();

  /// Thread that all queued lambdas are run on
  std::thread thread_;

  /// Mutex for the queue and exit status
  std::mutex mutex_;

  /// Monitor for the exit status and the queue
  std::condition_variable monitor_;

  /// Whether or not we want the thread to exit
  bool wantStop_;

  /// Queue of pending lambdas to call
  std::deque<Task> queue_;
};
|
|
60
|
+
|
|
61
|
+
} // namespace
|
|
@@ -0,0 +1,765 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/utils/distances.h>
|
|
11
|
+
|
|
12
|
+
#include <cstdio>
|
|
13
|
+
#include <cassert>
|
|
14
|
+
#include <cstring>
|
|
15
|
+
#include <cmath>
|
|
16
|
+
|
|
17
|
+
#include <omp.h>
|
|
18
|
+
|
|
19
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
20
|
+
#include <faiss/impl/FaissAssert.h>
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
#ifndef FINTEGER
|
|
25
|
+
#define FINTEGER long
|
|
26
|
+
#endif
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
extern "C" {
|
|
30
|
+
|
|
31
|
+
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
32
|
+
|
|
33
|
+
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
|
34
|
+
n, FINTEGER *k, const float *alpha, const float *a,
|
|
35
|
+
FINTEGER *lda, const float *b, FINTEGER *
|
|
36
|
+
ldb, float *beta, float *c, FINTEGER *ldc);
|
|
37
|
+
|
|
38
|
+
/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
|
|
39
|
+
|
|
40
|
+
int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
|
|
41
|
+
float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
|
|
42
|
+
|
|
43
|
+
int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
|
|
44
|
+
const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
|
|
45
|
+
float *beta, float *y, FINTEGER *incy);
|
|
46
|
+
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
namespace faiss {
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
/***************************************************************************
|
|
55
|
+
* Matrix/vector ops
|
|
56
|
+
***************************************************************************/
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
/* Compute the inner product between a vector x and
|
|
61
|
+
a set of ny vectors y.
|
|
62
|
+
These functions are not intended to replace BLAS matrix-matrix, as they
|
|
63
|
+
would be significantly less efficient in this case. */
|
|
64
|
+
void fvec_inner_products_ny (float * ip,
|
|
65
|
+
const float * x,
|
|
66
|
+
const float * y,
|
|
67
|
+
size_t d, size_t ny)
|
|
68
|
+
{
|
|
69
|
+
// Not sure which one is fastest
|
|
70
|
+
#if 0
|
|
71
|
+
{
|
|
72
|
+
FINTEGER di = d;
|
|
73
|
+
FINTEGER nyi = ny;
|
|
74
|
+
float one = 1.0, zero = 0.0;
|
|
75
|
+
FINTEGER onei = 1;
|
|
76
|
+
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
|
77
|
+
}
|
|
78
|
+
#endif
|
|
79
|
+
for (size_t i = 0; i < ny; i++) {
|
|
80
|
+
ip[i] = fvec_inner_product (x, y, d);
|
|
81
|
+
y += d;
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
/* Compute the L2 norm of a set of nx vectors */
|
|
90
|
+
void fvec_norms_L2 (float * __restrict nr,
|
|
91
|
+
const float * __restrict x,
|
|
92
|
+
size_t d, size_t nx)
|
|
93
|
+
{
|
|
94
|
+
|
|
95
|
+
#pragma omp parallel for
|
|
96
|
+
for (size_t i = 0; i < nx; i++) {
|
|
97
|
+
nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
void fvec_norms_L2sqr (float * __restrict nr,
|
|
102
|
+
const float * __restrict x,
|
|
103
|
+
size_t d, size_t nx)
|
|
104
|
+
{
|
|
105
|
+
#pragma omp parallel for
|
|
106
|
+
for (size_t i = 0; i < nx; i++)
|
|
107
|
+
nr[i] = fvec_norm_L2sqr (x + i * d, d);
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
|
113
|
+
{
|
|
114
|
+
#pragma omp parallel for
|
|
115
|
+
for (size_t i = 0; i < nx; i++) {
|
|
116
|
+
float * __restrict xi = x + i * d;
|
|
117
|
+
|
|
118
|
+
float nr = fvec_norm_L2sqr (xi, d);
|
|
119
|
+
|
|
120
|
+
if (nr > 0) {
|
|
121
|
+
size_t j;
|
|
122
|
+
const float inv_nr = 1.0 / sqrtf (nr);
|
|
123
|
+
for (j = 0; j < d; j++)
|
|
124
|
+
xi[j] *= inv_nr;
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
/***************************************************************************
|
|
141
|
+
* KNN functions
|
|
142
|
+
***************************************************************************/
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
147
|
+
static void knn_inner_product_sse (const float * x,
|
|
148
|
+
const float * y,
|
|
149
|
+
size_t d, size_t nx, size_t ny,
|
|
150
|
+
float_minheap_array_t * res)
|
|
151
|
+
{
|
|
152
|
+
size_t k = res->k;
|
|
153
|
+
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
|
154
|
+
|
|
155
|
+
check_period *= omp_get_max_threads();
|
|
156
|
+
|
|
157
|
+
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
|
158
|
+
size_t i1 = std::min(i0 + check_period, nx);
|
|
159
|
+
|
|
160
|
+
#pragma omp parallel for
|
|
161
|
+
for (size_t i = i0; i < i1; i++) {
|
|
162
|
+
const float * x_i = x + i * d;
|
|
163
|
+
const float * y_j = y;
|
|
164
|
+
|
|
165
|
+
float * __restrict simi = res->get_val(i);
|
|
166
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
|
167
|
+
|
|
168
|
+
minheap_heapify (k, simi, idxi);
|
|
169
|
+
|
|
170
|
+
for (size_t j = 0; j < ny; j++) {
|
|
171
|
+
float ip = fvec_inner_product (x_i, y_j, d);
|
|
172
|
+
|
|
173
|
+
if (ip > simi[0]) {
|
|
174
|
+
minheap_pop (k, simi, idxi);
|
|
175
|
+
minheap_push (k, simi, idxi, ip, j);
|
|
176
|
+
}
|
|
177
|
+
y_j += d;
|
|
178
|
+
}
|
|
179
|
+
minheap_reorder (k, simi, idxi);
|
|
180
|
+
}
|
|
181
|
+
InterruptCallback::check ();
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
static void knn_L2sqr_sse (
|
|
187
|
+
const float * x,
|
|
188
|
+
const float * y,
|
|
189
|
+
size_t d, size_t nx, size_t ny,
|
|
190
|
+
float_maxheap_array_t * res)
|
|
191
|
+
{
|
|
192
|
+
size_t k = res->k;
|
|
193
|
+
|
|
194
|
+
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
|
195
|
+
check_period *= omp_get_max_threads();
|
|
196
|
+
|
|
197
|
+
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
|
198
|
+
size_t i1 = std::min(i0 + check_period, nx);
|
|
199
|
+
|
|
200
|
+
#pragma omp parallel for
|
|
201
|
+
for (size_t i = i0; i < i1; i++) {
|
|
202
|
+
const float * x_i = x + i * d;
|
|
203
|
+
const float * y_j = y;
|
|
204
|
+
size_t j;
|
|
205
|
+
float * simi = res->get_val(i);
|
|
206
|
+
int64_t * idxi = res->get_ids (i);
|
|
207
|
+
|
|
208
|
+
maxheap_heapify (k, simi, idxi);
|
|
209
|
+
for (j = 0; j < ny; j++) {
|
|
210
|
+
float disij = fvec_L2sqr (x_i, y_j, d);
|
|
211
|
+
|
|
212
|
+
if (disij < simi[0]) {
|
|
213
|
+
maxheap_pop (k, simi, idxi);
|
|
214
|
+
maxheap_push (k, simi, idxi, disij, j);
|
|
215
|
+
}
|
|
216
|
+
y_j += d;
|
|
217
|
+
}
|
|
218
|
+
maxheap_reorder (k, simi, idxi);
|
|
219
|
+
}
|
|
220
|
+
InterruptCallback::check ();
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
227
|
+
static void knn_inner_product_blas (
|
|
228
|
+
const float * x,
|
|
229
|
+
const float * y,
|
|
230
|
+
size_t d, size_t nx, size_t ny,
|
|
231
|
+
float_minheap_array_t * res)
|
|
232
|
+
{
|
|
233
|
+
res->heapify ();
|
|
234
|
+
|
|
235
|
+
// BLAS does not like empty matrices
|
|
236
|
+
if (nx == 0 || ny == 0) return;
|
|
237
|
+
|
|
238
|
+
/* block sizes */
|
|
239
|
+
const size_t bs_x = 4096, bs_y = 1024;
|
|
240
|
+
// const size_t bs_x = 16, bs_y = 16;
|
|
241
|
+
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
242
|
+
|
|
243
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
244
|
+
size_t i1 = i0 + bs_x;
|
|
245
|
+
if(i1 > nx) i1 = nx;
|
|
246
|
+
|
|
247
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
248
|
+
size_t j1 = j0 + bs_y;
|
|
249
|
+
if (j1 > ny) j1 = ny;
|
|
250
|
+
/* compute the actual dot products */
|
|
251
|
+
{
|
|
252
|
+
float one = 1, zero = 0;
|
|
253
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
254
|
+
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
|
255
|
+
y + j0 * d, &di,
|
|
256
|
+
x + i0 * d, &di, &zero,
|
|
257
|
+
ip_block.get(), &nyi);
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
/* collect maxima */
|
|
261
|
+
res->addn (j1 - j0, ip_block.get(), j0, i0, i1 - i0);
|
|
262
|
+
}
|
|
263
|
+
InterruptCallback::check ();
|
|
264
|
+
}
|
|
265
|
+
res->reorder ();
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
// distance correction is an operator that can be applied to transform
|
|
269
|
+
// the distances
|
|
270
|
+
template<class DistanceCorrection>
|
|
271
|
+
static void knn_L2sqr_blas (const float * x,
|
|
272
|
+
const float * y,
|
|
273
|
+
size_t d, size_t nx, size_t ny,
|
|
274
|
+
float_maxheap_array_t * res,
|
|
275
|
+
const DistanceCorrection &corr)
|
|
276
|
+
{
|
|
277
|
+
res->heapify ();
|
|
278
|
+
|
|
279
|
+
// BLAS does not like empty matrices
|
|
280
|
+
if (nx == 0 || ny == 0) return;
|
|
281
|
+
|
|
282
|
+
size_t k = res->k;
|
|
283
|
+
|
|
284
|
+
/* block sizes */
|
|
285
|
+
const size_t bs_x = 4096, bs_y = 1024;
|
|
286
|
+
// const size_t bs_x = 16, bs_y = 16;
|
|
287
|
+
float *ip_block = new float[bs_x * bs_y];
|
|
288
|
+
float *x_norms = new float[nx];
|
|
289
|
+
float *y_norms = new float[ny];
|
|
290
|
+
ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
|
|
291
|
+
|
|
292
|
+
fvec_norms_L2sqr (x_norms, x, d, nx);
|
|
293
|
+
fvec_norms_L2sqr (y_norms, y, d, ny);
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
297
|
+
size_t i1 = i0 + bs_x;
|
|
298
|
+
if(i1 > nx) i1 = nx;
|
|
299
|
+
|
|
300
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
301
|
+
size_t j1 = j0 + bs_y;
|
|
302
|
+
if (j1 > ny) j1 = ny;
|
|
303
|
+
/* compute the actual dot products */
|
|
304
|
+
{
|
|
305
|
+
float one = 1, zero = 0;
|
|
306
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
307
|
+
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
|
308
|
+
y + j0 * d, &di,
|
|
309
|
+
x + i0 * d, &di, &zero,
|
|
310
|
+
ip_block, &nyi);
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
/* collect minima */
|
|
314
|
+
#pragma omp parallel for
|
|
315
|
+
for (size_t i = i0; i < i1; i++) {
|
|
316
|
+
float * __restrict simi = res->get_val(i);
|
|
317
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
|
318
|
+
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
|
319
|
+
|
|
320
|
+
for (size_t j = j0; j < j1; j++) {
|
|
321
|
+
float ip = *ip_line++;
|
|
322
|
+
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
|
323
|
+
|
|
324
|
+
// negative values can occur for identical vectors
|
|
325
|
+
// due to roundoff errors
|
|
326
|
+
if (dis < 0) dis = 0;
|
|
327
|
+
|
|
328
|
+
dis = corr (dis, i, j);
|
|
329
|
+
|
|
330
|
+
if (dis < simi[0]) {
|
|
331
|
+
maxheap_pop (k, simi, idxi);
|
|
332
|
+
maxheap_push (k, simi, idxi, dis, j);
|
|
333
|
+
}
|
|
334
|
+
}
|
|
335
|
+
}
|
|
336
|
+
}
|
|
337
|
+
InterruptCallback::check ();
|
|
338
|
+
}
|
|
339
|
+
res->reorder ();
|
|
340
|
+
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
/*******************************************************
|
|
352
|
+
* KNN driver functions
|
|
353
|
+
*******************************************************/
|
|
354
|
+
|
|
355
|
+
int distance_compute_blas_threshold = 20;
|
|
356
|
+
|
|
357
|
+
void knn_inner_product (const float * x,
|
|
358
|
+
const float * y,
|
|
359
|
+
size_t d, size_t nx, size_t ny,
|
|
360
|
+
float_minheap_array_t * res)
|
|
361
|
+
{
|
|
362
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
363
|
+
knn_inner_product_sse (x, y, d, nx, ny, res);
|
|
364
|
+
} else {
|
|
365
|
+
knn_inner_product_blas (x, y, d, nx, ny, res);
|
|
366
|
+
}
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
/* Distance correction that leaves every distance unchanged; the query
 * and database indices are ignored. */
struct NopDistanceCorrection {
    float operator() (float dis, size_t /*qno*/, size_t /*bno*/) const
    {
        return dis;
    }
};
|
|
376
|
+
|
|
377
|
+
void knn_L2sqr (const float * x,
|
|
378
|
+
const float * y,
|
|
379
|
+
size_t d, size_t nx, size_t ny,
|
|
380
|
+
float_maxheap_array_t * res)
|
|
381
|
+
{
|
|
382
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
383
|
+
knn_L2sqr_sse (x, y, d, nx, ny, res);
|
|
384
|
+
} else {
|
|
385
|
+
NopDistanceCorrection nop;
|
|
386
|
+
knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
|
|
387
|
+
}
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
/* Distance correction that subtracts a per-database-vector offset:
 * the distance to database vector bno is reduced by base_shift[bno].
 * The query index is ignored. Kept as a plain aggregate so it can be
 * brace-initialized. */
struct BaseShiftDistanceCorrection {
    /* one offset per database vector; not owned */
    const float *base_shift;

    float operator() (float dis, size_t /*qno*/, size_t bno) const
    {
        const float shift = base_shift[bno];
        return dis - shift;
    }
};
|
|
396
|
+
|
|
397
|
+
void knn_L2sqr_base_shift (
|
|
398
|
+
const float * x,
|
|
399
|
+
const float * y,
|
|
400
|
+
size_t d, size_t nx, size_t ny,
|
|
401
|
+
float_maxheap_array_t * res,
|
|
402
|
+
const float *base_shift)
|
|
403
|
+
{
|
|
404
|
+
BaseShiftDistanceCorrection corr = {base_shift};
|
|
405
|
+
knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
/***************************************************************************
|
|
411
|
+
* compute a subset of distances
|
|
412
|
+
***************************************************************************/
|
|
413
|
+
|
|
414
|
+
/* compute the inner product between x and a subset y of ny vectors,
|
|
415
|
+
whose indices are given by idy. */
|
|
416
|
+
void fvec_inner_products_by_idx (float * __restrict ip,
|
|
417
|
+
const float * x,
|
|
418
|
+
const float * y,
|
|
419
|
+
const int64_t * __restrict ids, /* for y vecs */
|
|
420
|
+
size_t d, size_t nx, size_t ny)
|
|
421
|
+
{
|
|
422
|
+
#pragma omp parallel for
|
|
423
|
+
for (size_t j = 0; j < nx; j++) {
|
|
424
|
+
const int64_t * __restrict idsj = ids + j * ny;
|
|
425
|
+
const float * xj = x + j * d;
|
|
426
|
+
float * __restrict ipj = ip + j * ny;
|
|
427
|
+
for (size_t i = 0; i < ny; i++) {
|
|
428
|
+
if (idsj[i] < 0)
|
|
429
|
+
continue;
|
|
430
|
+
ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
|
|
431
|
+
}
|
|
432
|
+
}
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
/* compute the inner product between x and a subset y of ny vectors,
|
|
438
|
+
whose indices are given by idy. */
|
|
439
|
+
void fvec_L2sqr_by_idx (float * __restrict dis,
|
|
440
|
+
const float * x,
|
|
441
|
+
const float * y,
|
|
442
|
+
const int64_t * __restrict ids, /* ids of y vecs */
|
|
443
|
+
size_t d, size_t nx, size_t ny)
|
|
444
|
+
{
|
|
445
|
+
#pragma omp parallel for
|
|
446
|
+
for (size_t j = 0; j < nx; j++) {
|
|
447
|
+
const int64_t * __restrict idsj = ids + j * ny;
|
|
448
|
+
const float * xj = x + j * d;
|
|
449
|
+
float * __restrict disj = dis + j * ny;
|
|
450
|
+
for (size_t i = 0; i < ny; i++) {
|
|
451
|
+
if (idsj[i] < 0)
|
|
452
|
+
continue;
|
|
453
|
+
disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
|
|
454
|
+
}
|
|
455
|
+
}
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
void pairwise_indexed_L2sqr (
|
|
459
|
+
size_t d, size_t n,
|
|
460
|
+
const float * x, const int64_t *ix,
|
|
461
|
+
const float * y, const int64_t *iy,
|
|
462
|
+
float *dis)
|
|
463
|
+
{
|
|
464
|
+
#pragma omp parallel for
|
|
465
|
+
for (size_t j = 0; j < n; j++) {
|
|
466
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
467
|
+
dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
|
|
468
|
+
}
|
|
469
|
+
}
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
void pairwise_indexed_inner_product (
|
|
473
|
+
size_t d, size_t n,
|
|
474
|
+
const float * x, const int64_t *ix,
|
|
475
|
+
const float * y, const int64_t *iy,
|
|
476
|
+
float *dis)
|
|
477
|
+
{
|
|
478
|
+
#pragma omp parallel for
|
|
479
|
+
for (size_t j = 0; j < n; j++) {
|
|
480
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
481
|
+
dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
|
|
482
|
+
}
|
|
483
|
+
}
|
|
484
|
+
}
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
|
488
|
+
indexed by ids. May be useful for re-ranking a pre-selected vector list */
|
|
489
|
+
void knn_inner_products_by_idx (const float * x,
|
|
490
|
+
const float * y,
|
|
491
|
+
const int64_t * ids,
|
|
492
|
+
size_t d, size_t nx, size_t ny,
|
|
493
|
+
float_minheap_array_t * res)
|
|
494
|
+
{
|
|
495
|
+
size_t k = res->k;
|
|
496
|
+
|
|
497
|
+
#pragma omp parallel for
|
|
498
|
+
for (size_t i = 0; i < nx; i++) {
|
|
499
|
+
const float * x_ = x + i * d;
|
|
500
|
+
const int64_t * idsi = ids + i * ny;
|
|
501
|
+
size_t j;
|
|
502
|
+
float * __restrict simi = res->get_val(i);
|
|
503
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
|
504
|
+
minheap_heapify (k, simi, idxi);
|
|
505
|
+
|
|
506
|
+
for (j = 0; j < ny; j++) {
|
|
507
|
+
if (idsi[j] < 0) break;
|
|
508
|
+
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
|
|
509
|
+
|
|
510
|
+
if (ip > simi[0]) {
|
|
511
|
+
minheap_pop (k, simi, idxi);
|
|
512
|
+
minheap_push (k, simi, idxi, ip, idsi[j]);
|
|
513
|
+
}
|
|
514
|
+
}
|
|
515
|
+
minheap_reorder (k, simi, idxi);
|
|
516
|
+
}
|
|
517
|
+
|
|
518
|
+
}
|
|
519
|
+
|
|
520
|
+
void knn_L2sqr_by_idx (const float * x,
|
|
521
|
+
const float * y,
|
|
522
|
+
const int64_t * __restrict ids,
|
|
523
|
+
size_t d, size_t nx, size_t ny,
|
|
524
|
+
float_maxheap_array_t * res)
|
|
525
|
+
{
|
|
526
|
+
size_t k = res->k;
|
|
527
|
+
|
|
528
|
+
#pragma omp parallel for
|
|
529
|
+
for (size_t i = 0; i < nx; i++) {
|
|
530
|
+
const float * x_ = x + i * d;
|
|
531
|
+
const int64_t * __restrict idsi = ids + i * ny;
|
|
532
|
+
float * __restrict simi = res->get_val(i);
|
|
533
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
|
534
|
+
maxheap_heapify (res->k, simi, idxi);
|
|
535
|
+
for (size_t j = 0; j < ny; j++) {
|
|
536
|
+
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
|
|
537
|
+
|
|
538
|
+
if (disij < simi[0]) {
|
|
539
|
+
maxheap_pop (k, simi, idxi);
|
|
540
|
+
maxheap_push (k, simi, idxi, disij, idsi[j]);
|
|
541
|
+
}
|
|
542
|
+
}
|
|
543
|
+
maxheap_reorder (res->k, simi, idxi);
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
/***************************************************************************
|
|
553
|
+
* Range search
|
|
554
|
+
***************************************************************************/
|
|
555
|
+
|
|
556
|
+
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
|
557
|
+
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
|
558
|
+
*/
|
|
559
|
+
template <bool compute_l2>
|
|
560
|
+
static void range_search_blas (
|
|
561
|
+
const float * x,
|
|
562
|
+
const float * y,
|
|
563
|
+
size_t d, size_t nx, size_t ny,
|
|
564
|
+
float radius,
|
|
565
|
+
RangeSearchResult *result)
|
|
566
|
+
{
|
|
567
|
+
|
|
568
|
+
// BLAS does not like empty matrices
|
|
569
|
+
if (nx == 0 || ny == 0) return;
|
|
570
|
+
|
|
571
|
+
/* block sizes */
|
|
572
|
+
const size_t bs_x = 4096, bs_y = 1024;
|
|
573
|
+
// const size_t bs_x = 16, bs_y = 16;
|
|
574
|
+
float *ip_block = new float[bs_x * bs_y];
|
|
575
|
+
ScopeDeleter<float> del0(ip_block);
|
|
576
|
+
|
|
577
|
+
float *x_norms = nullptr, *y_norms = nullptr;
|
|
578
|
+
ScopeDeleter<float> del1, del2;
|
|
579
|
+
if (compute_l2) {
|
|
580
|
+
x_norms = new float[nx];
|
|
581
|
+
del1.set (x_norms);
|
|
582
|
+
fvec_norms_L2sqr (x_norms, x, d, nx);
|
|
583
|
+
|
|
584
|
+
y_norms = new float[ny];
|
|
585
|
+
del2.set (y_norms);
|
|
586
|
+
fvec_norms_L2sqr (y_norms, y, d, ny);
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
std::vector <RangeSearchPartialResult *> partial_results;
|
|
590
|
+
|
|
591
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
592
|
+
size_t j1 = j0 + bs_y;
|
|
593
|
+
if (j1 > ny) j1 = ny;
|
|
594
|
+
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
|
|
595
|
+
partial_results.push_back (pres);
|
|
596
|
+
|
|
597
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
598
|
+
size_t i1 = i0 + bs_x;
|
|
599
|
+
if(i1 > nx) i1 = nx;
|
|
600
|
+
|
|
601
|
+
/* compute the actual dot products */
|
|
602
|
+
{
|
|
603
|
+
float one = 1, zero = 0;
|
|
604
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
605
|
+
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
|
606
|
+
y + j0 * d, &di,
|
|
607
|
+
x + i0 * d, &di, &zero,
|
|
608
|
+
ip_block, &nyi);
|
|
609
|
+
}
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
for (size_t i = i0; i < i1; i++) {
|
|
613
|
+
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
|
614
|
+
|
|
615
|
+
RangeQueryResult & qres = pres->new_result (i);
|
|
616
|
+
|
|
617
|
+
for (size_t j = j0; j < j1; j++) {
|
|
618
|
+
float ip = *ip_line++;
|
|
619
|
+
if (compute_l2) {
|
|
620
|
+
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
|
621
|
+
if (dis < radius) {
|
|
622
|
+
qres.add (dis, j);
|
|
623
|
+
}
|
|
624
|
+
} else {
|
|
625
|
+
if (ip > radius) {
|
|
626
|
+
qres.add (ip, j);
|
|
627
|
+
}
|
|
628
|
+
}
|
|
629
|
+
}
|
|
630
|
+
}
|
|
631
|
+
}
|
|
632
|
+
InterruptCallback::check ();
|
|
633
|
+
}
|
|
634
|
+
|
|
635
|
+
RangeSearchPartialResult::merge (partial_results);
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
template <bool compute_l2>
|
|
640
|
+
static void range_search_sse (const float * x,
|
|
641
|
+
const float * y,
|
|
642
|
+
size_t d, size_t nx, size_t ny,
|
|
643
|
+
float radius,
|
|
644
|
+
RangeSearchResult *res)
|
|
645
|
+
{
|
|
646
|
+
FAISS_THROW_IF_NOT (d % 4 == 0);
|
|
647
|
+
|
|
648
|
+
#pragma omp parallel
|
|
649
|
+
{
|
|
650
|
+
RangeSearchPartialResult pres (res);
|
|
651
|
+
|
|
652
|
+
#pragma omp for
|
|
653
|
+
for (size_t i = 0; i < nx; i++) {
|
|
654
|
+
const float * x_ = x + i * d;
|
|
655
|
+
const float * y_ = y;
|
|
656
|
+
size_t j;
|
|
657
|
+
|
|
658
|
+
RangeQueryResult & qres = pres.new_result (i);
|
|
659
|
+
|
|
660
|
+
for (j = 0; j < ny; j++) {
|
|
661
|
+
if (compute_l2) {
|
|
662
|
+
float disij = fvec_L2sqr (x_, y_, d);
|
|
663
|
+
if (disij < radius) {
|
|
664
|
+
qres.add (disij, j);
|
|
665
|
+
}
|
|
666
|
+
} else {
|
|
667
|
+
float ip = fvec_inner_product (x_, y_, d);
|
|
668
|
+
if (ip > radius) {
|
|
669
|
+
qres.add (ip, j);
|
|
670
|
+
}
|
|
671
|
+
}
|
|
672
|
+
y_ += d;
|
|
673
|
+
}
|
|
674
|
+
|
|
675
|
+
}
|
|
676
|
+
pres.finalize ();
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
// check just at the end because the use case is typically just
|
|
680
|
+
// when the nb of queries is low.
|
|
681
|
+
InterruptCallback::check();
|
|
682
|
+
}
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
void range_search_L2sqr (
|
|
689
|
+
const float * x,
|
|
690
|
+
const float * y,
|
|
691
|
+
size_t d, size_t nx, size_t ny,
|
|
692
|
+
float radius,
|
|
693
|
+
RangeSearchResult *res)
|
|
694
|
+
{
|
|
695
|
+
|
|
696
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
697
|
+
range_search_sse<true> (x, y, d, nx, ny, radius, res);
|
|
698
|
+
} else {
|
|
699
|
+
range_search_blas<true> (x, y, d, nx, ny, radius, res);
|
|
700
|
+
}
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
void range_search_inner_product (
|
|
704
|
+
const float * x,
|
|
705
|
+
const float * y,
|
|
706
|
+
size_t d, size_t nx, size_t ny,
|
|
707
|
+
float radius,
|
|
708
|
+
RangeSearchResult *res)
|
|
709
|
+
{
|
|
710
|
+
|
|
711
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
|
712
|
+
range_search_sse<false> (x, y, d, nx, ny, radius, res);
|
|
713
|
+
} else {
|
|
714
|
+
range_search_blas<false> (x, y, d, nx, ny, radius, res);
|
|
715
|
+
}
|
|
716
|
+
}
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
void pairwise_L2sqr (int64_t d,
|
|
720
|
+
int64_t nq, const float *xq,
|
|
721
|
+
int64_t nb, const float *xb,
|
|
722
|
+
float *dis,
|
|
723
|
+
int64_t ldq, int64_t ldb, int64_t ldd)
|
|
724
|
+
{
|
|
725
|
+
if (nq == 0 || nb == 0) return;
|
|
726
|
+
if (ldq == -1) ldq = d;
|
|
727
|
+
if (ldb == -1) ldb = d;
|
|
728
|
+
if (ldd == -1) ldd = nb;
|
|
729
|
+
|
|
730
|
+
// store in beginning of distance matrix to avoid malloc
|
|
731
|
+
float *b_norms = dis;
|
|
732
|
+
|
|
733
|
+
#pragma omp parallel for
|
|
734
|
+
for (int64_t i = 0; i < nb; i++)
|
|
735
|
+
b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
|
|
736
|
+
|
|
737
|
+
#pragma omp parallel for
|
|
738
|
+
for (int64_t i = 1; i < nq; i++) {
|
|
739
|
+
float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
|
|
740
|
+
for (int64_t j = 0; j < nb; j++)
|
|
741
|
+
dis[i * ldd + j] = q_norm + b_norms [j];
|
|
742
|
+
}
|
|
743
|
+
|
|
744
|
+
{
|
|
745
|
+
float q_norm = fvec_norm_L2sqr (xq, d);
|
|
746
|
+
for (int64_t j = 0; j < nb; j++)
|
|
747
|
+
dis[j] += q_norm;
|
|
748
|
+
}
|
|
749
|
+
|
|
750
|
+
{
|
|
751
|
+
FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
|
|
752
|
+
float one = 1.0, minus_2 = -2.0;
|
|
753
|
+
|
|
754
|
+
sgemm_ ("Transposed", "Not transposed",
|
|
755
|
+
&nbi, &nqi, &di,
|
|
756
|
+
&minus_2,
|
|
757
|
+
xb, &ldbi,
|
|
758
|
+
xq, &ldqi,
|
|
759
|
+
&one, dis, &lddi);
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
} // namespace faiss
|