faiss 0.1.5 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/README.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +6 -2
- data/ext/faiss/index.cpp +114 -43
- data/ext/faiss/index_binary.cpp +24 -30
- data/ext/faiss/kmeans.cpp +20 -16
- data/ext/faiss/numo.hpp +867 -0
- data/ext/faiss/pca_matrix.cpp +13 -14
- data/ext/faiss/product_quantizer.cpp +23 -24
- data/ext/faiss/utils.cpp +10 -37
- data/ext/faiss/utils.h +2 -13
- data/lib/faiss.rb +0 -5
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +24 -10
- data/lib/faiss/index.rb +0 -20
- data/lib/faiss/index_binary.rb +0 -20
- data/lib/faiss/kmeans.rb +0 -15
- data/lib/faiss/pca_matrix.rb +0 -15
- data/lib/faiss/product_quantizer.rb +0 -22
@@ -5,9 +5,8 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
|
9
|
-
#include <faiss/utils/WorkerThread.h>
|
10
8
|
#include <faiss/impl/FaissAssert.h>
|
9
|
+
#include <faiss/utils/WorkerThread.h>
|
11
10
|
#include <exception>
|
12
11
|
|
13
12
|
namespace faiss {
|
@@ -15,112 +14,104 @@ namespace faiss {
|
|
15
14
|
namespace {
|
16
15
|
|
17
16
|
// Captures any exceptions thrown by the lambda and returns them via the promise
|
18
|
-
void runCallback(std::function<void()>& fn,
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
}
|
17
|
+
void runCallback(std::function<void()>& fn, std::promise<bool>& promise) {
|
18
|
+
try {
|
19
|
+
fn();
|
20
|
+
promise.set_value(true);
|
21
|
+
} catch (...) {
|
22
|
+
promise.set_exception(std::current_exception());
|
23
|
+
}
|
26
24
|
}
|
27
25
|
|
28
26
|
} // namespace
|
29
27
|
|
30
|
-
WorkerThread::WorkerThread() :
|
31
|
-
|
32
|
-
startThread();
|
28
|
+
WorkerThread::WorkerThread() : wantStop_(false) {
|
29
|
+
startThread();
|
33
30
|
|
34
|
-
|
35
|
-
|
31
|
+
// Make sure that the thread has started before continuing
|
32
|
+
add([]() {}).get();
|
36
33
|
}
|
37
34
|
|
38
35
|
WorkerThread::~WorkerThread() {
|
39
|
-
|
40
|
-
|
36
|
+
stop();
|
37
|
+
waitForThreadExit();
|
41
38
|
}
|
42
39
|
|
43
|
-
void
|
44
|
-
|
45
|
-
thread_ = std::thread([this](){ threadMain(); });
|
40
|
+
void WorkerThread::startThread() {
|
41
|
+
thread_ = std::thread([this]() { threadMain(); });
|
46
42
|
}
|
47
43
|
|
48
|
-
void
|
49
|
-
|
50
|
-
std::lock_guard<std::mutex> guard(mutex_);
|
44
|
+
void WorkerThread::stop() {
|
45
|
+
std::lock_guard<std::mutex> guard(mutex_);
|
51
46
|
|
52
|
-
|
53
|
-
|
47
|
+
wantStop_ = true;
|
48
|
+
monitor_.notify_one();
|
54
49
|
}
|
55
50
|
|
56
|
-
std::future<bool>
|
57
|
-
|
58
|
-
std::lock_guard<std::mutex> guard(mutex_);
|
51
|
+
std::future<bool> WorkerThread::add(std::function<void()> f) {
|
52
|
+
std::lock_guard<std::mutex> guard(mutex_);
|
59
53
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
54
|
+
if (wantStop_) {
|
55
|
+
// The timer thread has been stopped, or we want to stop; we can't
|
56
|
+
// schedule anything else
|
57
|
+
std::promise<bool> p;
|
58
|
+
auto fut = p.get_future();
|
65
59
|
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
60
|
+
// did not execute
|
61
|
+
p.set_value(false);
|
62
|
+
return fut;
|
63
|
+
}
|
70
64
|
|
71
|
-
|
72
|
-
|
65
|
+
auto pr = std::promise<bool>();
|
66
|
+
auto fut = pr.get_future();
|
73
67
|
|
74
|
-
|
68
|
+
queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
|
75
69
|
|
76
|
-
|
77
|
-
|
78
|
-
|
70
|
+
// Wake up our thread
|
71
|
+
monitor_.notify_one();
|
72
|
+
return fut;
|
79
73
|
}
|
80
74
|
|
81
|
-
void
|
82
|
-
|
83
|
-
threadLoop();
|
75
|
+
void WorkerThread::threadMain() {
|
76
|
+
threadLoop();
|
84
77
|
|
85
|
-
|
86
|
-
|
78
|
+
// Call all pending tasks
|
79
|
+
FAISS_ASSERT(wantStop_);
|
87
80
|
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
81
|
+
// flush all pending operations
|
82
|
+
for (auto& f : queue_) {
|
83
|
+
runCallback(f.first, f.second);
|
84
|
+
}
|
92
85
|
}
|
93
86
|
|
94
|
-
void
|
95
|
-
|
96
|
-
|
97
|
-
std::pair<std::function<void()>, std::promise<bool>> data;
|
87
|
+
void WorkerThread::threadLoop() {
|
88
|
+
while (true) {
|
89
|
+
std::pair<std::function<void()>, std::promise<bool>> data;
|
98
90
|
|
99
|
-
|
100
|
-
|
91
|
+
{
|
92
|
+
std::unique_lock<std::mutex> lock(mutex_);
|
101
93
|
|
102
|
-
|
103
|
-
|
104
|
-
|
94
|
+
while (!wantStop_ && queue_.empty()) {
|
95
|
+
monitor_.wait(lock);
|
96
|
+
}
|
105
97
|
|
106
|
-
|
107
|
-
|
108
|
-
|
98
|
+
if (wantStop_) {
|
99
|
+
return;
|
100
|
+
}
|
109
101
|
|
110
|
-
|
111
|
-
|
112
|
-
|
102
|
+
data = std::move(queue_.front());
|
103
|
+
queue_.pop_front();
|
104
|
+
}
|
113
105
|
|
114
|
-
|
115
|
-
|
106
|
+
runCallback(data.first, data.second);
|
107
|
+
}
|
116
108
|
}
|
117
109
|
|
118
|
-
void
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
}
|
110
|
+
void WorkerThread::waitForThreadExit() {
|
111
|
+
try {
|
112
|
+
thread_.join();
|
113
|
+
} catch (...) {
|
114
|
+
}
|
124
115
|
}
|
125
116
|
|
126
|
-
} // namespace
|
117
|
+
} // namespace faiss
|
@@ -5,57 +5,56 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
|
9
8
|
#pragma once
|
10
9
|
|
11
10
|
#include <condition_variable>
|
12
|
-
#include <future>
|
13
11
|
#include <deque>
|
12
|
+
#include <future>
|
14
13
|
#include <thread>
|
15
14
|
|
16
15
|
namespace faiss {
|
17
16
|
|
18
17
|
class WorkerThread {
|
19
|
-
|
20
|
-
|
18
|
+
public:
|
19
|
+
WorkerThread();
|
21
20
|
|
22
|
-
|
23
|
-
|
24
|
-
|
21
|
+
/// Stops and waits for the worker thread to exit, flushing all
|
22
|
+
/// pending lambdas
|
23
|
+
~WorkerThread();
|
25
24
|
|
26
|
-
|
27
|
-
|
25
|
+
/// Request that the worker thread stop itself
|
26
|
+
void stop();
|
28
27
|
|
29
|
-
|
30
|
-
|
31
|
-
|
28
|
+
/// Blocking waits in the current thread for the worker thread to
|
29
|
+
/// stop
|
30
|
+
void waitForThreadExit();
|
32
31
|
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
32
|
+
/// Adds a lambda to run on the worker thread; returns a future that
|
33
|
+
/// can be used to block on its completion.
|
34
|
+
/// Future status is `true` if the lambda was run in the worker
|
35
|
+
/// thread; `false` if it was not run, because the worker thread is
|
36
|
+
/// exiting or has exited.
|
37
|
+
std::future<bool> add(std::function<void()> f);
|
39
38
|
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
39
|
+
private:
|
40
|
+
void startThread();
|
41
|
+
void threadMain();
|
42
|
+
void threadLoop();
|
44
43
|
|
45
|
-
|
46
|
-
|
44
|
+
/// Thread that all queued lambdas are run on
|
45
|
+
std::thread thread_;
|
47
46
|
|
48
|
-
|
49
|
-
|
47
|
+
/// Mutex for the queue and exit status
|
48
|
+
std::mutex mutex_;
|
50
49
|
|
51
|
-
|
52
|
-
|
50
|
+
/// Monitor for the exit status and the queue
|
51
|
+
std::condition_variable monitor_;
|
53
52
|
|
54
|
-
|
55
|
-
|
53
|
+
/// Whether or not we want the thread to exit
|
54
|
+
bool wantStop_;
|
56
55
|
|
57
|
-
|
58
|
-
|
56
|
+
/// Queue of pending lambdas to call
|
57
|
+
std::deque<std::pair<std::function<void()>, std::promise<bool>>> queue_;
|
59
58
|
};
|
60
59
|
|
61
|
-
} // namespace
|
60
|
+
} // namespace faiss
|
@@ -10,10 +10,10 @@
|
|
10
10
|
#include <faiss/utils/distances.h>
|
11
11
|
|
12
12
|
#include <algorithm>
|
13
|
-
#include <cstdio>
|
14
13
|
#include <cassert>
|
15
|
-
#include <cstring>
|
16
14
|
#include <cmath>
|
15
|
+
#include <cstdio>
|
16
|
+
#include <cstring>
|
17
17
|
|
18
18
|
#include <omp.h>
|
19
19
|
|
@@ -21,186 +21,151 @@
|
|
21
21
|
#include <faiss/impl/FaissAssert.h>
|
22
22
|
#include <faiss/impl/ResultHandler.h>
|
23
23
|
|
24
|
-
|
25
|
-
|
26
24
|
#ifndef FINTEGER
|
27
25
|
#define FINTEGER long
|
28
26
|
#endif
|
29
27
|
|
30
|
-
|
31
28
|
extern "C" {
|
32
29
|
|
33
30
|
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
34
31
|
|
35
|
-
int sgemm_
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
32
|
+
int sgemm_(
|
33
|
+
const char* transa,
|
34
|
+
const char* transb,
|
35
|
+
FINTEGER* m,
|
36
|
+
FINTEGER* n,
|
37
|
+
FINTEGER* k,
|
38
|
+
const float* alpha,
|
39
|
+
const float* a,
|
40
|
+
FINTEGER* lda,
|
41
|
+
const float* b,
|
42
|
+
FINTEGER* ldb,
|
43
|
+
float* beta,
|
44
|
+
float* c,
|
45
|
+
FINTEGER* ldc);
|
41
46
|
}
|
42
47
|
|
43
|
-
|
44
48
|
namespace faiss {
|
45
49
|
|
46
|
-
|
47
|
-
|
48
50
|
/***************************************************************************
|
49
51
|
* Matrix/vector ops
|
50
52
|
***************************************************************************/
|
51
53
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
54
|
/* Compute the L2 norm of a set of nx vectors */
|
56
|
-
void fvec_norms_L2
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
55
|
+
void fvec_norms_L2(
|
56
|
+
float* __restrict nr,
|
57
|
+
const float* __restrict x,
|
58
|
+
size_t d,
|
59
|
+
size_t nx) {
|
61
60
|
#pragma omp parallel for
|
62
61
|
for (int64_t i = 0; i < nx; i++) {
|
63
|
-
nr[i] = sqrtf
|
62
|
+
nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
|
64
63
|
}
|
65
64
|
}
|
66
65
|
|
67
|
-
void fvec_norms_L2sqr
|
68
|
-
|
69
|
-
|
70
|
-
|
66
|
+
void fvec_norms_L2sqr(
|
67
|
+
float* __restrict nr,
|
68
|
+
const float* __restrict x,
|
69
|
+
size_t d,
|
70
|
+
size_t nx) {
|
71
71
|
#pragma omp parallel for
|
72
72
|
for (int64_t i = 0; i < nx; i++)
|
73
|
-
nr[i] = fvec_norm_L2sqr
|
73
|
+
nr[i] = fvec_norm_L2sqr(x + i * d, d);
|
74
74
|
}
|
75
75
|
|
76
|
-
|
77
|
-
|
78
|
-
void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
79
|
-
{
|
76
|
+
void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
80
77
|
#pragma omp parallel for
|
81
78
|
for (int64_t i = 0; i < nx; i++) {
|
82
|
-
float
|
79
|
+
float* __restrict xi = x + i * d;
|
83
80
|
|
84
|
-
float nr = fvec_norm_L2sqr
|
81
|
+
float nr = fvec_norm_L2sqr(xi, d);
|
85
82
|
|
86
83
|
if (nr > 0) {
|
87
84
|
size_t j;
|
88
|
-
const float inv_nr = 1.0 / sqrtf
|
85
|
+
const float inv_nr = 1.0 / sqrtf(nr);
|
89
86
|
for (j = 0; j < d; j++)
|
90
87
|
xi[j] *= inv_nr;
|
91
88
|
}
|
92
89
|
}
|
93
90
|
}
|
94
91
|
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
92
|
/***************************************************************************
|
107
93
|
* KNN functions
|
108
94
|
***************************************************************************/
|
109
95
|
|
110
96
|
namespace {
|
111
97
|
|
112
|
-
|
113
|
-
|
114
98
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
115
|
-
template<class ResultHandler>
|
116
|
-
void exhaustive_inner_product_seq
|
117
|
-
const float
|
118
|
-
const float
|
119
|
-
size_t d,
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
check_period *= omp_get_max_threads();
|
125
|
-
|
99
|
+
template <class ResultHandler>
|
100
|
+
void exhaustive_inner_product_seq(
|
101
|
+
const float* x,
|
102
|
+
const float* y,
|
103
|
+
size_t d,
|
104
|
+
size_t nx,
|
105
|
+
size_t ny,
|
106
|
+
ResultHandler& res) {
|
126
107
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
127
108
|
|
128
|
-
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
129
|
-
size_t i1 = std::min(i0 + check_period, nx);
|
130
|
-
|
131
109
|
#pragma omp parallel
|
132
|
-
|
133
|
-
|
110
|
+
{
|
111
|
+
SingleResultHandler resi(res);
|
134
112
|
#pragma omp for
|
135
|
-
|
136
|
-
|
137
|
-
|
113
|
+
for (int64_t i = 0; i < nx; i++) {
|
114
|
+
const float* x_i = x + i * d;
|
115
|
+
const float* y_j = y;
|
138
116
|
|
139
|
-
|
117
|
+
resi.begin(i);
|
140
118
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
}
|
146
|
-
resi.end();
|
119
|
+
for (size_t j = 0; j < ny; j++) {
|
120
|
+
float ip = fvec_inner_product(x_i, y_j, d);
|
121
|
+
resi.add_result(ip, j);
|
122
|
+
y_j += d;
|
147
123
|
}
|
124
|
+
resi.end();
|
148
125
|
}
|
149
|
-
InterruptCallback::check ();
|
150
126
|
}
|
151
|
-
|
152
127
|
}
|
153
128
|
|
154
|
-
template<class ResultHandler>
|
155
|
-
void exhaustive_L2sqr_seq
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
163
|
-
check_period *= omp_get_max_threads();
|
129
|
+
template <class ResultHandler>
|
130
|
+
void exhaustive_L2sqr_seq(
|
131
|
+
const float* x,
|
132
|
+
const float* y,
|
133
|
+
size_t d,
|
134
|
+
size_t nx,
|
135
|
+
size_t ny,
|
136
|
+
ResultHandler& res) {
|
164
137
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
165
138
|
|
166
|
-
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
167
|
-
size_t i1 = std::min(i0 + check_period, nx);
|
168
|
-
|
169
139
|
#pragma omp parallel
|
170
|
-
|
171
|
-
|
140
|
+
{
|
141
|
+
SingleResultHandler resi(res);
|
172
142
|
#pragma omp for
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
}
|
182
|
-
resi.end();
|
143
|
+
for (int64_t i = 0; i < nx; i++) {
|
144
|
+
const float* x_i = x + i * d;
|
145
|
+
const float* y_j = y;
|
146
|
+
resi.begin(i);
|
147
|
+
for (size_t j = 0; j < ny; j++) {
|
148
|
+
float disij = fvec_L2sqr(x_i, y_j, d);
|
149
|
+
resi.add_result(disij, j);
|
150
|
+
y_j += d;
|
183
151
|
}
|
152
|
+
resi.end();
|
184
153
|
}
|
185
|
-
InterruptCallback::check ();
|
186
154
|
}
|
187
|
-
|
188
|
-
};
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
155
|
+
}
|
193
156
|
|
194
157
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
195
|
-
template<class ResultHandler>
|
196
|
-
void exhaustive_inner_product_blas
|
197
|
-
const float
|
198
|
-
const float
|
199
|
-
size_t d,
|
200
|
-
|
201
|
-
|
158
|
+
template <class ResultHandler>
|
159
|
+
void exhaustive_inner_product_blas(
|
160
|
+
const float* x,
|
161
|
+
const float* y,
|
162
|
+
size_t d,
|
163
|
+
size_t nx,
|
164
|
+
size_t ny,
|
165
|
+
ResultHandler& res) {
|
202
166
|
// BLAS does not like empty matrices
|
203
|
-
if (nx == 0 || ny == 0)
|
167
|
+
if (nx == 0 || ny == 0)
|
168
|
+
return;
|
204
169
|
|
205
170
|
/* block sizes */
|
206
171
|
const size_t bs_x = distance_compute_blas_query_bs;
|
@@ -209,86 +174,105 @@ void exhaustive_inner_product_blas (
|
|
209
174
|
|
210
175
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
211
176
|
size_t i1 = i0 + bs_x;
|
212
|
-
if(i1 > nx)
|
177
|
+
if (i1 > nx)
|
178
|
+
i1 = nx;
|
213
179
|
|
214
180
|
res.begin_multiple(i0, i1);
|
215
181
|
|
216
182
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
217
183
|
size_t j1 = j0 + bs_y;
|
218
|
-
if (j1 > ny)
|
184
|
+
if (j1 > ny)
|
185
|
+
j1 = ny;
|
219
186
|
/* compute the actual dot products */
|
220
187
|
{
|
221
188
|
float one = 1, zero = 0;
|
222
189
|
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
223
|
-
sgemm_
|
224
|
-
|
225
|
-
|
226
|
-
|
190
|
+
sgemm_("Transpose",
|
191
|
+
"Not transpose",
|
192
|
+
&nyi,
|
193
|
+
&nxi,
|
194
|
+
&di,
|
195
|
+
&one,
|
196
|
+
y + j0 * d,
|
197
|
+
&di,
|
198
|
+
x + i0 * d,
|
199
|
+
&di,
|
200
|
+
&zero,
|
201
|
+
ip_block.get(),
|
202
|
+
&nyi);
|
227
203
|
}
|
228
204
|
|
229
205
|
res.add_results(j0, j1, ip_block.get());
|
230
|
-
|
231
206
|
}
|
232
207
|
res.end_multiple();
|
233
|
-
InterruptCallback::check
|
234
|
-
|
208
|
+
InterruptCallback::check();
|
235
209
|
}
|
236
210
|
}
|
237
211
|
|
238
|
-
|
239
|
-
|
240
|
-
|
241
212
|
// distance correction is an operator that can be applied to transform
|
242
213
|
// the distances
|
243
|
-
template<class ResultHandler>
|
244
|
-
void exhaustive_L2sqr_blas
|
245
|
-
const float
|
246
|
-
const float
|
247
|
-
size_t d,
|
248
|
-
|
249
|
-
|
250
|
-
|
214
|
+
template <class ResultHandler>
|
215
|
+
void exhaustive_L2sqr_blas(
|
216
|
+
const float* x,
|
217
|
+
const float* y,
|
218
|
+
size_t d,
|
219
|
+
size_t nx,
|
220
|
+
size_t ny,
|
221
|
+
ResultHandler& res,
|
222
|
+
const float* y_norms = nullptr) {
|
251
223
|
// BLAS does not like empty matrices
|
252
|
-
if (nx == 0 || ny == 0)
|
224
|
+
if (nx == 0 || ny == 0)
|
225
|
+
return;
|
253
226
|
|
254
227
|
/* block sizes */
|
255
228
|
const size_t bs_x = distance_compute_blas_query_bs;
|
256
229
|
const size_t bs_y = distance_compute_blas_database_bs;
|
257
230
|
// const size_t bs_x = 16, bs_y = 16;
|
258
|
-
std::unique_ptr<float
|
259
|
-
std::unique_ptr<float
|
260
|
-
std::unique_ptr<float
|
231
|
+
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
232
|
+
std::unique_ptr<float[]> x_norms(new float[nx]);
|
233
|
+
std::unique_ptr<float[]> del2;
|
261
234
|
|
262
|
-
fvec_norms_L2sqr
|
235
|
+
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
263
236
|
|
264
237
|
if (!y_norms) {
|
265
|
-
float
|
238
|
+
float* y_norms2 = new float[ny];
|
266
239
|
del2.reset(y_norms2);
|
267
|
-
fvec_norms_L2sqr
|
240
|
+
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
268
241
|
y_norms = y_norms2;
|
269
242
|
}
|
270
243
|
|
271
244
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
272
245
|
size_t i1 = i0 + bs_x;
|
273
|
-
if(i1 > nx)
|
246
|
+
if (i1 > nx)
|
247
|
+
i1 = nx;
|
274
248
|
|
275
249
|
res.begin_multiple(i0, i1);
|
276
250
|
|
277
251
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
278
252
|
size_t j1 = j0 + bs_y;
|
279
|
-
if (j1 > ny)
|
253
|
+
if (j1 > ny)
|
254
|
+
j1 = ny;
|
280
255
|
/* compute the actual dot products */
|
281
256
|
{
|
282
257
|
float one = 1, zero = 0;
|
283
258
|
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
284
|
-
sgemm_
|
285
|
-
|
286
|
-
|
287
|
-
|
259
|
+
sgemm_("Transpose",
|
260
|
+
"Not transpose",
|
261
|
+
&nyi,
|
262
|
+
&nxi,
|
263
|
+
&di,
|
264
|
+
&one,
|
265
|
+
y + j0 * d,
|
266
|
+
&di,
|
267
|
+
x + i0 * d,
|
268
|
+
&di,
|
269
|
+
&zero,
|
270
|
+
ip_block.get(),
|
271
|
+
&nyi);
|
288
272
|
}
|
289
|
-
|
273
|
+
#pragma omp parallel for
|
290
274
|
for (int64_t i = i0; i < i1; i++) {
|
291
|
-
float
|
275
|
+
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
292
276
|
|
293
277
|
for (size_t j = j0; j < j1; j++) {
|
294
278
|
float ip = *ip_line;
|
@@ -296,7 +280,8 @@ void exhaustive_L2sqr_blas (
|
|
296
280
|
|
297
281
|
// negative values can occur for identical vectors
|
298
282
|
// due to roundoff errors
|
299
|
-
if (dis < 0)
|
283
|
+
if (dis < 0)
|
284
|
+
dis = 0;
|
300
285
|
|
301
286
|
*ip_line = dis;
|
302
287
|
ip_line++;
|
@@ -305,18 +290,12 @@ void exhaustive_L2sqr_blas (
|
|
305
290
|
res.add_results(j0, j1, ip_block.get());
|
306
291
|
}
|
307
292
|
res.end_multiple();
|
308
|
-
InterruptCallback::check
|
293
|
+
InterruptCallback::check();
|
309
294
|
}
|
310
295
|
}
|
311
296
|
|
312
|
-
|
313
|
-
|
314
297
|
} // anonymous namespace
|
315
298
|
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
299
|
/*******************************************************
|
321
300
|
* KNN driver functions
|
322
301
|
*******************************************************/
|
@@ -326,268 +305,275 @@ int distance_compute_blas_query_bs = 4096;
|
|
326
305
|
int distance_compute_blas_database_bs = 1024;
|
327
306
|
int distance_compute_min_k_reservoir = 100;
|
328
307
|
|
329
|
-
void knn_inner_product
|
330
|
-
const float
|
331
|
-
|
332
|
-
|
333
|
-
|
308
|
+
void knn_inner_product(
|
309
|
+
const float* x,
|
310
|
+
const float* y,
|
311
|
+
size_t d,
|
312
|
+
size_t nx,
|
313
|
+
size_t ny,
|
314
|
+
float_minheap_array_t* ha) {
|
334
315
|
if (ha->k < distance_compute_min_k_reservoir) {
|
335
316
|
HeapResultHandler<CMin<float, int64_t>> res(
|
336
|
-
|
317
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
337
318
|
if (nx < distance_compute_blas_threshold) {
|
338
|
-
exhaustive_inner_product_seq
|
319
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
339
320
|
} else {
|
340
|
-
exhaustive_inner_product_blas
|
321
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
341
322
|
}
|
342
323
|
} else {
|
343
324
|
ReservoirResultHandler<CMin<float, int64_t>> res(
|
344
|
-
|
325
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
345
326
|
if (nx < distance_compute_blas_threshold) {
|
346
|
-
exhaustive_inner_product_seq
|
327
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
347
328
|
} else {
|
348
|
-
exhaustive_inner_product_blas
|
329
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
349
330
|
}
|
350
331
|
}
|
351
332
|
}
|
352
333
|
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
const float *y_norm2
|
362
|
-
) {
|
363
|
-
|
334
|
+
void knn_L2sqr(
|
335
|
+
const float* x,
|
336
|
+
const float* y,
|
337
|
+
size_t d,
|
338
|
+
size_t nx,
|
339
|
+
size_t ny,
|
340
|
+
float_maxheap_array_t* ha,
|
341
|
+
const float* y_norm2) {
|
364
342
|
if (ha->k < distance_compute_min_k_reservoir) {
|
365
343
|
HeapResultHandler<CMax<float, int64_t>> res(
|
366
|
-
|
344
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
367
345
|
|
368
346
|
if (nx < distance_compute_blas_threshold) {
|
369
|
-
exhaustive_L2sqr_seq
|
347
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
370
348
|
} else {
|
371
|
-
exhaustive_L2sqr_blas
|
349
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
372
350
|
}
|
373
351
|
} else {
|
374
352
|
ReservoirResultHandler<CMax<float, int64_t>> res(
|
375
|
-
|
353
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
376
354
|
if (nx < distance_compute_blas_threshold) {
|
377
|
-
exhaustive_L2sqr_seq
|
355
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
378
356
|
} else {
|
379
|
-
exhaustive_L2sqr_blas
|
357
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
380
358
|
}
|
381
359
|
}
|
382
360
|
}
|
383
361
|
|
384
|
-
|
385
362
|
/***************************************************************************
|
386
363
|
* Range search
|
387
364
|
***************************************************************************/
|
388
365
|
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
size_t d, size_t nx, size_t ny,
|
366
|
+
void range_search_L2sqr(
|
367
|
+
const float* x,
|
368
|
+
const float* y,
|
369
|
+
size_t d,
|
370
|
+
size_t nx,
|
371
|
+
size_t ny,
|
396
372
|
float radius,
|
397
|
-
RangeSearchResult
|
398
|
-
{
|
373
|
+
RangeSearchResult* res) {
|
399
374
|
RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
|
400
375
|
if (nx < distance_compute_blas_threshold) {
|
401
|
-
exhaustive_L2sqr_seq
|
376
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
|
402
377
|
} else {
|
403
|
-
exhaustive_L2sqr_blas
|
378
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
|
404
379
|
}
|
405
380
|
}
|
406
381
|
|
407
|
-
void range_search_inner_product
|
408
|
-
const float
|
409
|
-
const float
|
410
|
-
size_t d,
|
382
|
+
void range_search_inner_product(
|
383
|
+
const float* x,
|
384
|
+
const float* y,
|
385
|
+
size_t d,
|
386
|
+
size_t nx,
|
387
|
+
size_t ny,
|
411
388
|
float radius,
|
412
|
-
RangeSearchResult
|
413
|
-
{
|
414
|
-
|
389
|
+
RangeSearchResult* res) {
|
415
390
|
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
|
416
391
|
if (nx < distance_compute_blas_threshold) {
|
417
|
-
exhaustive_inner_product_seq
|
392
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
|
418
393
|
} else {
|
419
|
-
exhaustive_inner_product_blas
|
394
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
|
420
395
|
}
|
421
396
|
}
|
422
397
|
|
423
|
-
|
424
398
|
/***************************************************************************
|
425
399
|
* compute a subset of distances
|
426
400
|
***************************************************************************/
|
427
401
|
|
428
402
|
/* compute the inner product between x and a subset y of ny vectors,
|
429
403
|
whose indices are given by ids. */
|
430
|
-
void fvec_inner_products_by_idx
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
404
|
+
void fvec_inner_products_by_idx(
|
405
|
+
float* __restrict ip,
|
406
|
+
const float* x,
|
407
|
+
const float* y,
|
408
|
+
const int64_t* __restrict ids, /* for y vecs */
|
409
|
+
size_t d,
|
410
|
+
size_t nx,
|
411
|
+
size_t ny) {
|
436
412
|
#pragma omp parallel for
|
437
413
|
for (int64_t j = 0; j < nx; j++) {
|
438
|
-
const int64_t
|
439
|
-
const float
|
440
|
-
float
|
414
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
415
|
+
const float* xj = x + j * d;
|
416
|
+
float* __restrict ipj = ip + j * ny;
|
441
417
|
for (size_t i = 0; i < ny; i++) {
|
442
418
|
if (idsj[i] < 0)
|
443
419
|
continue;
|
444
|
-
ipj[i] = fvec_inner_product
|
420
|
+
ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
|
445
421
|
}
|
446
422
|
}
|
447
423
|
}
|
448
424
|
|
449
|
-
|
450
|
-
|
451
425
|
/* compute the squared L2 distance between x and a subset y of ny vectors,
|
452
426
|
whose indices are given by ids. */
|
453
|
-
void fvec_L2sqr_by_idx
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
427
|
+
void fvec_L2sqr_by_idx(
|
428
|
+
float* __restrict dis,
|
429
|
+
const float* x,
|
430
|
+
const float* y,
|
431
|
+
const int64_t* __restrict ids, /* ids of y vecs */
|
432
|
+
size_t d,
|
433
|
+
size_t nx,
|
434
|
+
size_t ny) {
|
459
435
|
#pragma omp parallel for
|
460
436
|
for (int64_t j = 0; j < nx; j++) {
|
461
|
-
const int64_t
|
462
|
-
const float
|
463
|
-
float
|
437
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
438
|
+
const float* xj = x + j * d;
|
439
|
+
float* __restrict disj = dis + j * ny;
|
464
440
|
for (size_t i = 0; i < ny; i++) {
|
465
441
|
if (idsj[i] < 0)
|
466
442
|
continue;
|
467
|
-
disj[i] = fvec_L2sqr
|
443
|
+
disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
|
468
444
|
}
|
469
445
|
}
|
470
446
|
}
|
471
447
|
|
472
|
-
void pairwise_indexed_L2sqr
|
473
|
-
size_t d,
|
474
|
-
|
475
|
-
const float
|
476
|
-
|
477
|
-
|
448
|
+
void pairwise_indexed_L2sqr(
|
449
|
+
size_t d,
|
450
|
+
size_t n,
|
451
|
+
const float* x,
|
452
|
+
const int64_t* ix,
|
453
|
+
const float* y,
|
454
|
+
const int64_t* iy,
|
455
|
+
float* dis) {
|
478
456
|
#pragma omp parallel for
|
479
457
|
for (int64_t j = 0; j < n; j++) {
|
480
458
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
481
|
-
dis[j] = fvec_L2sqr
|
459
|
+
dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
|
482
460
|
}
|
483
461
|
}
|
484
462
|
}
|
485
463
|
|
486
|
-
void pairwise_indexed_inner_product
|
487
|
-
size_t d,
|
488
|
-
|
489
|
-
const float
|
490
|
-
|
491
|
-
|
464
|
+
void pairwise_indexed_inner_product(
|
465
|
+
size_t d,
|
466
|
+
size_t n,
|
467
|
+
const float* x,
|
468
|
+
const int64_t* ix,
|
469
|
+
const float* y,
|
470
|
+
const int64_t* iy,
|
471
|
+
float* dis) {
|
492
472
|
#pragma omp parallel for
|
493
473
|
for (int64_t j = 0; j < n; j++) {
|
494
474
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
495
|
-
dis[j] = fvec_inner_product
|
475
|
+
dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
|
496
476
|
}
|
497
477
|
}
|
498
478
|
}
|
499
479
|
|
500
|
-
|
501
480
|
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
502
481
|
indexed by ids. May be useful for re-ranking a pre-selected vector list */
|
503
|
-
void knn_inner_products_by_idx
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
482
|
+
void knn_inner_products_by_idx(
|
483
|
+
const float* x,
|
484
|
+
const float* y,
|
485
|
+
const int64_t* ids,
|
486
|
+
size_t d,
|
487
|
+
size_t nx,
|
488
|
+
size_t ny,
|
489
|
+
float_minheap_array_t* res) {
|
509
490
|
size_t k = res->k;
|
510
491
|
|
511
492
|
#pragma omp parallel for
|
512
493
|
for (int64_t i = 0; i < nx; i++) {
|
513
|
-
const float
|
514
|
-
const int64_t
|
494
|
+
const float* x_ = x + i * d;
|
495
|
+
const int64_t* idsi = ids + i * ny;
|
515
496
|
size_t j;
|
516
|
-
float
|
517
|
-
int64_t
|
518
|
-
minheap_heapify
|
497
|
+
float* __restrict simi = res->get_val(i);
|
498
|
+
int64_t* __restrict idxi = res->get_ids(i);
|
499
|
+
minheap_heapify(k, simi, idxi);
|
519
500
|
|
520
501
|
for (j = 0; j < ny; j++) {
|
521
|
-
if (idsi[j] < 0)
|
522
|
-
|
502
|
+
if (idsi[j] < 0)
|
503
|
+
break;
|
504
|
+
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
|
523
505
|
|
524
506
|
if (ip > simi[0]) {
|
525
|
-
minheap_replace_top
|
507
|
+
minheap_replace_top(k, simi, idxi, ip, idsi[j]);
|
526
508
|
}
|
527
509
|
}
|
528
|
-
minheap_reorder
|
510
|
+
minheap_reorder(k, simi, idxi);
|
529
511
|
}
|
530
|
-
|
531
512
|
}
|
532
513
|
|
533
|
-
void knn_L2sqr_by_idx
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
514
|
+
void knn_L2sqr_by_idx(
|
515
|
+
const float* x,
|
516
|
+
const float* y,
|
517
|
+
const int64_t* __restrict ids,
|
518
|
+
size_t d,
|
519
|
+
size_t nx,
|
520
|
+
size_t ny,
|
521
|
+
float_maxheap_array_t* res) {
|
539
522
|
size_t k = res->k;
|
540
523
|
|
541
524
|
#pragma omp parallel for
|
542
525
|
for (int64_t i = 0; i < nx; i++) {
|
543
|
-
const float
|
544
|
-
const int64_t
|
545
|
-
float
|
546
|
-
int64_t
|
547
|
-
maxheap_heapify
|
526
|
+
const float* x_ = x + i * d;
|
527
|
+
const int64_t* __restrict idsi = ids + i * ny;
|
528
|
+
float* __restrict simi = res->get_val(i);
|
529
|
+
int64_t* __restrict idxi = res->get_ids(i);
|
530
|
+
maxheap_heapify(res->k, simi, idxi);
|
548
531
|
for (size_t j = 0; j < ny; j++) {
|
549
|
-
float disij = fvec_L2sqr
|
532
|
+
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
550
533
|
|
551
534
|
if (disij < simi[0]) {
|
552
|
-
maxheap_replace_top
|
535
|
+
maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
|
553
536
|
}
|
554
537
|
}
|
555
|
-
maxheap_reorder
|
538
|
+
maxheap_reorder(res->k, simi, idxi);
|
556
539
|
}
|
557
|
-
|
558
540
|
}
|
559
541
|
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
{
|
570
|
-
if (nq == 0 || nb == 0)
|
571
|
-
|
572
|
-
if (
|
573
|
-
|
542
|
+
void pairwise_L2sqr(
|
543
|
+
int64_t d,
|
544
|
+
int64_t nq,
|
545
|
+
const float* xq,
|
546
|
+
int64_t nb,
|
547
|
+
const float* xb,
|
548
|
+
float* dis,
|
549
|
+
int64_t ldq,
|
550
|
+
int64_t ldb,
|
551
|
+
int64_t ldd) {
|
552
|
+
if (nq == 0 || nb == 0)
|
553
|
+
return;
|
554
|
+
if (ldq == -1)
|
555
|
+
ldq = d;
|
556
|
+
if (ldb == -1)
|
557
|
+
ldb = d;
|
558
|
+
if (ldd == -1)
|
559
|
+
ldd = nb;
|
574
560
|
|
575
561
|
// store in beginning of distance matrix to avoid malloc
|
576
|
-
float
|
562
|
+
float* b_norms = dis;
|
577
563
|
|
578
564
|
#pragma omp parallel for
|
579
565
|
for (int64_t i = 0; i < nb; i++)
|
580
|
-
b_norms
|
566
|
+
b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
|
581
567
|
|
582
568
|
#pragma omp parallel for
|
583
569
|
for (int64_t i = 1; i < nq; i++) {
|
584
|
-
float q_norm = fvec_norm_L2sqr
|
570
|
+
float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
|
585
571
|
for (int64_t j = 0; j < nb; j++)
|
586
|
-
dis[i * ldd + j] = q_norm + b_norms
|
572
|
+
dis[i * ldd + j] = q_norm + b_norms[j];
|
587
573
|
}
|
588
574
|
|
589
575
|
{
|
590
|
-
float q_norm = fvec_norm_L2sqr
|
576
|
+
float q_norm = fvec_norm_L2sqr(xq, d);
|
591
577
|
for (int64_t j = 0; j < nb; j++)
|
592
578
|
dis[j] += q_norm;
|
593
579
|
}
|
@@ -596,22 +582,28 @@ void pairwise_L2sqr (int64_t d,
|
|
596
582
|
FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
|
597
583
|
float one = 1.0, minus_2 = -2.0;
|
598
584
|
|
599
|
-
sgemm_
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
585
|
+
sgemm_("Transposed",
|
586
|
+
"Not transposed",
|
587
|
+
&nbi,
|
588
|
+
&nqi,
|
589
|
+
&di,
|
590
|
+
&minus_2,
|
591
|
+
xb,
|
592
|
+
&ldbi,
|
593
|
+
xq,
|
594
|
+
&ldqi,
|
595
|
+
&one,
|
596
|
+
dis,
|
597
|
+
&lddi);
|
605
598
|
}
|
606
|
-
|
607
599
|
}
|
608
600
|
|
609
|
-
void inner_product_to_L2sqr(
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
601
|
+
void inner_product_to_L2sqr(
|
602
|
+
float* __restrict dis,
|
603
|
+
const float* nr1,
|
604
|
+
const float* nr2,
|
605
|
+
size_t n1,
|
606
|
+
size_t n2) {
|
615
607
|
#pragma omp parallel for
|
616
608
|
for (int64_t j = 0; j < n1; j++) {
|
617
609
|
float* disj = dis + j * n2;
|
@@ -620,5 +612,4 @@ void inner_product_to_L2sqr(float* __restrict dis,
|
|
620
612
|
}
|
621
613
|
}
|
622
614
|
|
623
|
-
|
624
615
|
} // namespace faiss
|