faiss 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
@@ -0,0 +1,126 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
|
9
|
+
#include <faiss/utils/WorkerThread.h>
|
10
|
+
#include <faiss/impl/FaissAssert.h>
|
11
|
+
#include <exception>
|
12
|
+
|
13
|
+
namespace faiss {
|
14
|
+
|
15
|
+
namespace {
|
16
|
+
|
17
|
+
// Runs `fn` and publishes the outcome through `promise`: the value `true`
// when everything completed, otherwise the exception that was raised.
void runCallback(std::function<void()>& fn,
                 std::promise<bool>& promise) {
    std::exception_ptr failure;

    try {
        fn();
        promise.set_value(true);
    } catch (...) {
        // remember the in-flight exception; it is forwarded below
        failure = std::current_exception();
    }

    if (failure) {
        promise.set_exception(failure);
    }
}
|
27
|
+
|
28
|
+
} // namespace
|
29
|
+
|
30
|
+
// Starts the worker thread and returns only once it has executed a first
// (empty) task.
WorkerThread::WorkerThread()
        : wantStop_(false) {
    startThread();

    // Waiting on a no-op task shows that the thread is up and is
    // servicing the queue before the constructor returns.
    std::future<bool> started = add([]() {});
    started.get();
}
|
37
|
+
|
38
|
+
// Requests shutdown, then blocks until the worker thread has been joined.
WorkerThread::~WorkerThread() {
    this->stop();
    this->waitForThreadExit();
}
|
42
|
+
|
43
|
+
void
|
44
|
+
WorkerThread::startThread() {
|
45
|
+
thread_ = std::thread([this](){ threadMain(); });
|
46
|
+
}
|
47
|
+
|
48
|
+
void
|
49
|
+
WorkerThread::stop() {
|
50
|
+
std::lock_guard<std::mutex> guard(mutex_);
|
51
|
+
|
52
|
+
wantStop_ = true;
|
53
|
+
monitor_.notify_one();
|
54
|
+
}
|
55
|
+
|
56
|
+
std::future<bool>
|
57
|
+
WorkerThread::add(std::function<void()> f) {
|
58
|
+
std::lock_guard<std::mutex> guard(mutex_);
|
59
|
+
|
60
|
+
if (wantStop_) {
|
61
|
+
// The timer thread has been stopped, or we want to stop; we can't
|
62
|
+
// schedule anything else
|
63
|
+
std::promise<bool> p;
|
64
|
+
auto fut = p.get_future();
|
65
|
+
|
66
|
+
// did not execute
|
67
|
+
p.set_value(false);
|
68
|
+
return fut;
|
69
|
+
}
|
70
|
+
|
71
|
+
auto pr = std::promise<bool>();
|
72
|
+
auto fut = pr.get_future();
|
73
|
+
|
74
|
+
queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
|
75
|
+
|
76
|
+
// Wake up our thread
|
77
|
+
monitor_.notify_one();
|
78
|
+
return fut;
|
79
|
+
}
|
80
|
+
|
81
|
+
void
|
82
|
+
WorkerThread::threadMain() {
|
83
|
+
threadLoop();
|
84
|
+
|
85
|
+
// Call all pending tasks
|
86
|
+
FAISS_ASSERT(wantStop_);
|
87
|
+
|
88
|
+
// flush all pending operations
|
89
|
+
for (auto& f : queue_) {
|
90
|
+
runCallback(f.first, f.second);
|
91
|
+
}
|
92
|
+
}
|
93
|
+
|
94
|
+
void
|
95
|
+
WorkerThread::threadLoop() {
|
96
|
+
while (true) {
|
97
|
+
std::pair<std::function<void()>, std::promise<bool>> data;
|
98
|
+
|
99
|
+
{
|
100
|
+
std::unique_lock<std::mutex> lock(mutex_);
|
101
|
+
|
102
|
+
while (!wantStop_ && queue_.empty()) {
|
103
|
+
monitor_.wait(lock);
|
104
|
+
}
|
105
|
+
|
106
|
+
if (wantStop_) {
|
107
|
+
return;
|
108
|
+
}
|
109
|
+
|
110
|
+
data = std::move(queue_.front());
|
111
|
+
queue_.pop_front();
|
112
|
+
}
|
113
|
+
|
114
|
+
runCallback(data.first, data.second);
|
115
|
+
}
|
116
|
+
}
|
117
|
+
|
118
|
+
void
|
119
|
+
WorkerThread::waitForThreadExit() {
|
120
|
+
try {
|
121
|
+
thread_.join();
|
122
|
+
} catch (...) {
|
123
|
+
}
|
124
|
+
}
|
125
|
+
|
126
|
+
} // namespace
|
@@ -0,0 +1,61 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
|
9
|
+
#pragma once
|
10
|
+
|
11
|
+
#include <condition_variable>
|
12
|
+
#include <future>
|
13
|
+
#include <deque>
|
14
|
+
#include <thread>
|
15
|
+
|
16
|
+
namespace faiss {
|
17
|
+
|
18
|
+
/// Owns a single background thread that executes queued lambdas in the
/// order they were added.
class WorkerThread {
  public:
    /// Starts the worker thread.
    WorkerThread();

    /// Requests a stop, then waits for the worker thread to exit; lambdas
    /// still pending at that point are flushed by the worker first.
    ~WorkerThread();

    /// Schedules `f` on the worker thread and returns a future that can be
    /// waited on. The future yields `true` when `f` ran on the worker
    /// thread and `false` when it was not run because the worker is
    /// exiting or has exited.
    std::future<bool> add(std::function<void()> f);

    /// Asks the worker thread to stop; does not wait for it.
    void stop();

    /// Blocks the calling thread until the worker thread has stopped.
    void waitForThreadExit();

  private:
    void threadLoop();
    void threadMain();
    void startThread();

    /// The thread on which every queued lambda is executed
    std::thread thread_;

    /// Guards `queue_` and `wantStop_`
    std::mutex mutex_;

    /// Signalled when the queue or the stop flag changes
    std::condition_variable monitor_;

    /// Set once the thread has been asked to exit
    bool wantStop_;

    /// Lambdas waiting to run, each paired with the promise for its result
    std::deque<std::pair<std::function<void()>, std::promise<bool>>> queue_;
};
|
60
|
+
|
61
|
+
} // namespace
|
@@ -0,0 +1,765 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/utils/distances.h>
|
11
|
+
|
12
|
+
#include <cstdio>
|
13
|
+
#include <cassert>
|
14
|
+
#include <cstring>
|
15
|
+
#include <cmath>
|
16
|
+
|
17
|
+
#include <omp.h>
|
18
|
+
|
19
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
20
|
+
#include <faiss/impl/FaissAssert.h>
|
21
|
+
|
22
|
+
|
23
|
+
|
24
|
+
#ifndef FINTEGER
|
25
|
+
#define FINTEGER long
|
26
|
+
#endif
|
27
|
+
|
28
|
+
|
29
|
+
extern "C" {
|
30
|
+
|
31
|
+
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
32
|
+
|
33
|
+
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
34
|
+
n, FINTEGER *k, const float *alpha, const float *a,
|
35
|
+
FINTEGER *lda, const float *b, FINTEGER *
|
36
|
+
ldb, float *beta, float *c, FINTEGER *ldc);
|
37
|
+
|
38
|
+
/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
|
39
|
+
|
40
|
+
int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
|
41
|
+
float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
|
42
|
+
|
43
|
+
int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
|
44
|
+
const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
|
45
|
+
float *beta, float *y, FINTEGER *incy);
|
46
|
+
|
47
|
+
}
|
48
|
+
|
49
|
+
|
50
|
+
namespace faiss {
|
51
|
+
|
52
|
+
|
53
|
+
|
54
|
+
/***************************************************************************
|
55
|
+
* Matrix/vector ops
|
56
|
+
***************************************************************************/
|
57
|
+
|
58
|
+
|
59
|
+
|
60
|
+
/* Compute the inner product between a vector x and
|
61
|
+
a set of ny vectors y.
|
62
|
+
These functions are not intended to replace BLAS matrix-matrix, as they
|
63
|
+
would be significantly less efficient in this case. */
|
64
|
+
void fvec_inner_products_ny (float * ip,
|
65
|
+
const float * x,
|
66
|
+
const float * y,
|
67
|
+
size_t d, size_t ny)
|
68
|
+
{
|
69
|
+
// Not sure which one is fastest
|
70
|
+
#if 0
|
71
|
+
{
|
72
|
+
FINTEGER di = d;
|
73
|
+
FINTEGER nyi = ny;
|
74
|
+
float one = 1.0, zero = 0.0;
|
75
|
+
FINTEGER onei = 1;
|
76
|
+
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
77
|
+
}
|
78
|
+
#endif
|
79
|
+
for (size_t i = 0; i < ny; i++) {
|
80
|
+
ip[i] = fvec_inner_product (x, y, d);
|
81
|
+
y += d;
|
82
|
+
}
|
83
|
+
}
|
84
|
+
|
85
|
+
|
86
|
+
|
87
|
+
|
88
|
+
|
89
|
+
/* Compute the L2 norm of a set of nx vectors */
|
90
|
+
void fvec_norms_L2 (float * __restrict nr,
|
91
|
+
const float * __restrict x,
|
92
|
+
size_t d, size_t nx)
|
93
|
+
{
|
94
|
+
|
95
|
+
#pragma omp parallel for
|
96
|
+
for (size_t i = 0; i < nx; i++) {
|
97
|
+
nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
|
98
|
+
}
|
99
|
+
}
|
100
|
+
|
101
|
+
void fvec_norms_L2sqr (float * __restrict nr,
|
102
|
+
const float * __restrict x,
|
103
|
+
size_t d, size_t nx)
|
104
|
+
{
|
105
|
+
#pragma omp parallel for
|
106
|
+
for (size_t i = 0; i < nx; i++)
|
107
|
+
nr[i] = fvec_norm_L2sqr (x + i * d, d);
|
108
|
+
}
|
109
|
+
|
110
|
+
|
111
|
+
|
112
|
+
void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
113
|
+
{
|
114
|
+
#pragma omp parallel for
|
115
|
+
for (size_t i = 0; i < nx; i++) {
|
116
|
+
float * __restrict xi = x + i * d;
|
117
|
+
|
118
|
+
float nr = fvec_norm_L2sqr (xi, d);
|
119
|
+
|
120
|
+
if (nr > 0) {
|
121
|
+
size_t j;
|
122
|
+
const float inv_nr = 1.0 / sqrtf (nr);
|
123
|
+
for (j = 0; j < d; j++)
|
124
|
+
xi[j] *= inv_nr;
|
125
|
+
}
|
126
|
+
}
|
127
|
+
}
|
128
|
+
|
129
|
+
|
130
|
+
|
131
|
+
|
132
|
+
|
133
|
+
|
134
|
+
|
135
|
+
|
136
|
+
|
137
|
+
|
138
|
+
|
139
|
+
|
140
|
+
/***************************************************************************
|
141
|
+
* KNN functions
|
142
|
+
***************************************************************************/
|
143
|
+
|
144
|
+
|
145
|
+
|
146
|
+
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
147
|
+
static void knn_inner_product_sse (const float * x,
|
148
|
+
const float * y,
|
149
|
+
size_t d, size_t nx, size_t ny,
|
150
|
+
float_minheap_array_t * res)
|
151
|
+
{
|
152
|
+
size_t k = res->k;
|
153
|
+
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
154
|
+
|
155
|
+
check_period *= omp_get_max_threads();
|
156
|
+
|
157
|
+
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
158
|
+
size_t i1 = std::min(i0 + check_period, nx);
|
159
|
+
|
160
|
+
#pragma omp parallel for
|
161
|
+
for (size_t i = i0; i < i1; i++) {
|
162
|
+
const float * x_i = x + i * d;
|
163
|
+
const float * y_j = y;
|
164
|
+
|
165
|
+
float * __restrict simi = res->get_val(i);
|
166
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
167
|
+
|
168
|
+
minheap_heapify (k, simi, idxi);
|
169
|
+
|
170
|
+
for (size_t j = 0; j < ny; j++) {
|
171
|
+
float ip = fvec_inner_product (x_i, y_j, d);
|
172
|
+
|
173
|
+
if (ip > simi[0]) {
|
174
|
+
minheap_pop (k, simi, idxi);
|
175
|
+
minheap_push (k, simi, idxi, ip, j);
|
176
|
+
}
|
177
|
+
y_j += d;
|
178
|
+
}
|
179
|
+
minheap_reorder (k, simi, idxi);
|
180
|
+
}
|
181
|
+
InterruptCallback::check ();
|
182
|
+
}
|
183
|
+
|
184
|
+
}
|
185
|
+
|
186
|
+
static void knn_L2sqr_sse (
|
187
|
+
const float * x,
|
188
|
+
const float * y,
|
189
|
+
size_t d, size_t nx, size_t ny,
|
190
|
+
float_maxheap_array_t * res)
|
191
|
+
{
|
192
|
+
size_t k = res->k;
|
193
|
+
|
194
|
+
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
195
|
+
check_period *= omp_get_max_threads();
|
196
|
+
|
197
|
+
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
198
|
+
size_t i1 = std::min(i0 + check_period, nx);
|
199
|
+
|
200
|
+
#pragma omp parallel for
|
201
|
+
for (size_t i = i0; i < i1; i++) {
|
202
|
+
const float * x_i = x + i * d;
|
203
|
+
const float * y_j = y;
|
204
|
+
size_t j;
|
205
|
+
float * simi = res->get_val(i);
|
206
|
+
int64_t * idxi = res->get_ids (i);
|
207
|
+
|
208
|
+
maxheap_heapify (k, simi, idxi);
|
209
|
+
for (j = 0; j < ny; j++) {
|
210
|
+
float disij = fvec_L2sqr (x_i, y_j, d);
|
211
|
+
|
212
|
+
if (disij < simi[0]) {
|
213
|
+
maxheap_pop (k, simi, idxi);
|
214
|
+
maxheap_push (k, simi, idxi, disij, j);
|
215
|
+
}
|
216
|
+
y_j += d;
|
217
|
+
}
|
218
|
+
maxheap_reorder (k, simi, idxi);
|
219
|
+
}
|
220
|
+
InterruptCallback::check ();
|
221
|
+
}
|
222
|
+
|
223
|
+
}
|
224
|
+
|
225
|
+
|
226
|
+
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
227
|
+
static void knn_inner_product_blas (
|
228
|
+
const float * x,
|
229
|
+
const float * y,
|
230
|
+
size_t d, size_t nx, size_t ny,
|
231
|
+
float_minheap_array_t * res)
|
232
|
+
{
|
233
|
+
res->heapify ();
|
234
|
+
|
235
|
+
// BLAS does not like empty matrices
|
236
|
+
if (nx == 0 || ny == 0) return;
|
237
|
+
|
238
|
+
/* block sizes */
|
239
|
+
const size_t bs_x = 4096, bs_y = 1024;
|
240
|
+
// const size_t bs_x = 16, bs_y = 16;
|
241
|
+
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
242
|
+
|
243
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
244
|
+
size_t i1 = i0 + bs_x;
|
245
|
+
if(i1 > nx) i1 = nx;
|
246
|
+
|
247
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
248
|
+
size_t j1 = j0 + bs_y;
|
249
|
+
if (j1 > ny) j1 = ny;
|
250
|
+
/* compute the actual dot products */
|
251
|
+
{
|
252
|
+
float one = 1, zero = 0;
|
253
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
254
|
+
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
255
|
+
y + j0 * d, &di,
|
256
|
+
x + i0 * d, &di, &zero,
|
257
|
+
ip_block.get(), &nyi);
|
258
|
+
}
|
259
|
+
|
260
|
+
/* collect maxima */
|
261
|
+
res->addn (j1 - j0, ip_block.get(), j0, i0, i1 - i0);
|
262
|
+
}
|
263
|
+
InterruptCallback::check ();
|
264
|
+
}
|
265
|
+
res->reorder ();
|
266
|
+
}
|
267
|
+
|
268
|
+
// distance correction is an operator that can be applied to transform
|
269
|
+
// the distances
|
270
|
+
template<class DistanceCorrection>
|
271
|
+
static void knn_L2sqr_blas (const float * x,
|
272
|
+
const float * y,
|
273
|
+
size_t d, size_t nx, size_t ny,
|
274
|
+
float_maxheap_array_t * res,
|
275
|
+
const DistanceCorrection &corr)
|
276
|
+
{
|
277
|
+
res->heapify ();
|
278
|
+
|
279
|
+
// BLAS does not like empty matrices
|
280
|
+
if (nx == 0 || ny == 0) return;
|
281
|
+
|
282
|
+
size_t k = res->k;
|
283
|
+
|
284
|
+
/* block sizes */
|
285
|
+
const size_t bs_x = 4096, bs_y = 1024;
|
286
|
+
// const size_t bs_x = 16, bs_y = 16;
|
287
|
+
float *ip_block = new float[bs_x * bs_y];
|
288
|
+
float *x_norms = new float[nx];
|
289
|
+
float *y_norms = new float[ny];
|
290
|
+
ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
|
291
|
+
|
292
|
+
fvec_norms_L2sqr (x_norms, x, d, nx);
|
293
|
+
fvec_norms_L2sqr (y_norms, y, d, ny);
|
294
|
+
|
295
|
+
|
296
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
297
|
+
size_t i1 = i0 + bs_x;
|
298
|
+
if(i1 > nx) i1 = nx;
|
299
|
+
|
300
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
301
|
+
size_t j1 = j0 + bs_y;
|
302
|
+
if (j1 > ny) j1 = ny;
|
303
|
+
/* compute the actual dot products */
|
304
|
+
{
|
305
|
+
float one = 1, zero = 0;
|
306
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
307
|
+
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
308
|
+
y + j0 * d, &di,
|
309
|
+
x + i0 * d, &di, &zero,
|
310
|
+
ip_block, &nyi);
|
311
|
+
}
|
312
|
+
|
313
|
+
/* collect minima */
|
314
|
+
#pragma omp parallel for
|
315
|
+
for (size_t i = i0; i < i1; i++) {
|
316
|
+
float * __restrict simi = res->get_val(i);
|
317
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
318
|
+
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
319
|
+
|
320
|
+
for (size_t j = j0; j < j1; j++) {
|
321
|
+
float ip = *ip_line++;
|
322
|
+
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
323
|
+
|
324
|
+
// negative values can occur for identical vectors
|
325
|
+
// due to roundoff errors
|
326
|
+
if (dis < 0) dis = 0;
|
327
|
+
|
328
|
+
dis = corr (dis, i, j);
|
329
|
+
|
330
|
+
if (dis < simi[0]) {
|
331
|
+
maxheap_pop (k, simi, idxi);
|
332
|
+
maxheap_push (k, simi, idxi, dis, j);
|
333
|
+
}
|
334
|
+
}
|
335
|
+
}
|
336
|
+
}
|
337
|
+
InterruptCallback::check ();
|
338
|
+
}
|
339
|
+
res->reorder ();
|
340
|
+
|
341
|
+
}
|
342
|
+
|
343
|
+
|
344
|
+
|
345
|
+
|
346
|
+
|
347
|
+
|
348
|
+
|
349
|
+
|
350
|
+
|
351
|
+
/*******************************************************
|
352
|
+
* KNN driver functions
|
353
|
+
*******************************************************/
|
354
|
+
|
355
|
+
int distance_compute_blas_threshold = 20;
|
356
|
+
|
357
|
+
void knn_inner_product (const float * x,
|
358
|
+
const float * y,
|
359
|
+
size_t d, size_t nx, size_t ny,
|
360
|
+
float_minheap_array_t * res)
|
361
|
+
{
|
362
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
363
|
+
knn_inner_product_sse (x, y, d, nx, ny, res);
|
364
|
+
} else {
|
365
|
+
knn_inner_product_blas (x, y, d, nx, ny, res);
|
366
|
+
}
|
367
|
+
}
|
368
|
+
|
369
|
+
|
370
|
+
|
371
|
+
/* Distance correction that leaves the distance unchanged; the query and
   database indices are ignored. */
struct NopDistanceCorrection {
    float operator()(float dis, size_t, size_t) const
    {
        const float unchanged = dis;
        return unchanged;
    }
};
|
376
|
+
|
377
|
+
void knn_L2sqr (const float * x,
|
378
|
+
const float * y,
|
379
|
+
size_t d, size_t nx, size_t ny,
|
380
|
+
float_maxheap_array_t * res)
|
381
|
+
{
|
382
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
383
|
+
knn_L2sqr_sse (x, y, d, nx, ny, res);
|
384
|
+
} else {
|
385
|
+
NopDistanceCorrection nop;
|
386
|
+
knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
|
387
|
+
}
|
388
|
+
}
|
389
|
+
|
390
|
+
/* Distance correction that subtracts a per-database-vector offset:
   base_shift[bno] is taken off the distance; the query index is ignored. */
struct BaseShiftDistanceCorrection {
    const float *base_shift;

    float operator()(float dis, size_t, size_t bno) const
    {
        const float shift = base_shift[bno];
        return dis - shift;
    }
};
|
396
|
+
|
397
|
+
void knn_L2sqr_base_shift (
|
398
|
+
const float * x,
|
399
|
+
const float * y,
|
400
|
+
size_t d, size_t nx, size_t ny,
|
401
|
+
float_maxheap_array_t * res,
|
402
|
+
const float *base_shift)
|
403
|
+
{
|
404
|
+
BaseShiftDistanceCorrection corr = {base_shift};
|
405
|
+
knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
|
406
|
+
}
|
407
|
+
|
408
|
+
|
409
|
+
|
410
|
+
/***************************************************************************
|
411
|
+
* compute a subset of distances
|
412
|
+
***************************************************************************/
|
413
|
+
|
414
|
+
/* compute the inner product between x and a subset y of ny vectors,
|
415
|
+
whose indices are given by ids. */
|
416
|
+
void fvec_inner_products_by_idx (float * __restrict ip,
|
417
|
+
const float * x,
|
418
|
+
const float * y,
|
419
|
+
const int64_t * __restrict ids, /* for y vecs */
|
420
|
+
size_t d, size_t nx, size_t ny)
|
421
|
+
{
|
422
|
+
#pragma omp parallel for
|
423
|
+
for (size_t j = 0; j < nx; j++) {
|
424
|
+
const int64_t * __restrict idsj = ids + j * ny;
|
425
|
+
const float * xj = x + j * d;
|
426
|
+
float * __restrict ipj = ip + j * ny;
|
427
|
+
for (size_t i = 0; i < ny; i++) {
|
428
|
+
if (idsj[i] < 0)
|
429
|
+
continue;
|
430
|
+
ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
|
431
|
+
}
|
432
|
+
}
|
433
|
+
}
|
434
|
+
|
435
|
+
|
436
|
+
|
437
|
+
/* compute the squared L2 distance between x and a subset y of ny vectors,
|
438
|
+
whose indices are given by ids. */
|
439
|
+
void fvec_L2sqr_by_idx (float * __restrict dis,
|
440
|
+
const float * x,
|
441
|
+
const float * y,
|
442
|
+
const int64_t * __restrict ids, /* ids of y vecs */
|
443
|
+
size_t d, size_t nx, size_t ny)
|
444
|
+
{
|
445
|
+
#pragma omp parallel for
|
446
|
+
for (size_t j = 0; j < nx; j++) {
|
447
|
+
const int64_t * __restrict idsj = ids + j * ny;
|
448
|
+
const float * xj = x + j * d;
|
449
|
+
float * __restrict disj = dis + j * ny;
|
450
|
+
for (size_t i = 0; i < ny; i++) {
|
451
|
+
if (idsj[i] < 0)
|
452
|
+
continue;
|
453
|
+
disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
|
454
|
+
}
|
455
|
+
}
|
456
|
+
}
|
457
|
+
|
458
|
+
void pairwise_indexed_L2sqr (
|
459
|
+
size_t d, size_t n,
|
460
|
+
const float * x, const int64_t *ix,
|
461
|
+
const float * y, const int64_t *iy,
|
462
|
+
float *dis)
|
463
|
+
{
|
464
|
+
#pragma omp parallel for
|
465
|
+
for (size_t j = 0; j < n; j++) {
|
466
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
467
|
+
dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
|
468
|
+
}
|
469
|
+
}
|
470
|
+
}
|
471
|
+
|
472
|
+
void pairwise_indexed_inner_product (
|
473
|
+
size_t d, size_t n,
|
474
|
+
const float * x, const int64_t *ix,
|
475
|
+
const float * y, const int64_t *iy,
|
476
|
+
float *dis)
|
477
|
+
{
|
478
|
+
#pragma omp parallel for
|
479
|
+
for (size_t j = 0; j < n; j++) {
|
480
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
481
|
+
dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
|
482
|
+
}
|
483
|
+
}
|
484
|
+
}
|
485
|
+
|
486
|
+
|
487
|
+
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
488
|
+
indexed by ids. May be useful for re-ranking a pre-selected vector list */
|
489
|
+
void knn_inner_products_by_idx (const float * x,
|
490
|
+
const float * y,
|
491
|
+
const int64_t * ids,
|
492
|
+
size_t d, size_t nx, size_t ny,
|
493
|
+
float_minheap_array_t * res)
|
494
|
+
{
|
495
|
+
size_t k = res->k;
|
496
|
+
|
497
|
+
#pragma omp parallel for
|
498
|
+
for (size_t i = 0; i < nx; i++) {
|
499
|
+
const float * x_ = x + i * d;
|
500
|
+
const int64_t * idsi = ids + i * ny;
|
501
|
+
size_t j;
|
502
|
+
float * __restrict simi = res->get_val(i);
|
503
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
504
|
+
minheap_heapify (k, simi, idxi);
|
505
|
+
|
506
|
+
for (j = 0; j < ny; j++) {
|
507
|
+
if (idsi[j] < 0) break;
|
508
|
+
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
|
509
|
+
|
510
|
+
if (ip > simi[0]) {
|
511
|
+
minheap_pop (k, simi, idxi);
|
512
|
+
minheap_push (k, simi, idxi, ip, idsi[j]);
|
513
|
+
}
|
514
|
+
}
|
515
|
+
minheap_reorder (k, simi, idxi);
|
516
|
+
}
|
517
|
+
|
518
|
+
}
|
519
|
+
|
520
|
+
void knn_L2sqr_by_idx (const float * x,
|
521
|
+
const float * y,
|
522
|
+
const int64_t * __restrict ids,
|
523
|
+
size_t d, size_t nx, size_t ny,
|
524
|
+
float_maxheap_array_t * res)
|
525
|
+
{
|
526
|
+
size_t k = res->k;
|
527
|
+
|
528
|
+
#pragma omp parallel for
|
529
|
+
for (size_t i = 0; i < nx; i++) {
|
530
|
+
const float * x_ = x + i * d;
|
531
|
+
const int64_t * __restrict idsi = ids + i * ny;
|
532
|
+
float * __restrict simi = res->get_val(i);
|
533
|
+
int64_t * __restrict idxi = res->get_ids (i);
|
534
|
+
maxheap_heapify (res->k, simi, idxi);
|
535
|
+
for (size_t j = 0; j < ny; j++) {
|
536
|
+
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
|
537
|
+
|
538
|
+
if (disij < simi[0]) {
|
539
|
+
maxheap_pop (k, simi, idxi);
|
540
|
+
maxheap_push (k, simi, idxi, disij, idsi[j]);
|
541
|
+
}
|
542
|
+
}
|
543
|
+
maxheap_reorder (res->k, simi, idxi);
|
544
|
+
}
|
545
|
+
|
546
|
+
}
|
547
|
+
|
548
|
+
|
549
|
+
|
550
|
+
|
551
|
+
|
552
|
+
/***************************************************************************
|
553
|
+
* Range search
|
554
|
+
***************************************************************************/
|
555
|
+
|
556
|
+
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
557
|
+
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
558
|
+
*/
|
559
|
+
template <bool compute_l2>
|
560
|
+
static void range_search_blas (
|
561
|
+
const float * x,
|
562
|
+
const float * y,
|
563
|
+
size_t d, size_t nx, size_t ny,
|
564
|
+
float radius,
|
565
|
+
RangeSearchResult *result)
|
566
|
+
{
|
567
|
+
|
568
|
+
// BLAS does not like empty matrices
|
569
|
+
if (nx == 0 || ny == 0) return;
|
570
|
+
|
571
|
+
/* block sizes */
|
572
|
+
const size_t bs_x = 4096, bs_y = 1024;
|
573
|
+
// const size_t bs_x = 16, bs_y = 16;
|
574
|
+
float *ip_block = new float[bs_x * bs_y];
|
575
|
+
ScopeDeleter<float> del0(ip_block);
|
576
|
+
|
577
|
+
float *x_norms = nullptr, *y_norms = nullptr;
|
578
|
+
ScopeDeleter<float> del1, del2;
|
579
|
+
if (compute_l2) {
|
580
|
+
x_norms = new float[nx];
|
581
|
+
del1.set (x_norms);
|
582
|
+
fvec_norms_L2sqr (x_norms, x, d, nx);
|
583
|
+
|
584
|
+
y_norms = new float[ny];
|
585
|
+
del2.set (y_norms);
|
586
|
+
fvec_norms_L2sqr (y_norms, y, d, ny);
|
587
|
+
}
|
588
|
+
|
589
|
+
std::vector <RangeSearchPartialResult *> partial_results;
|
590
|
+
|
591
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
592
|
+
size_t j1 = j0 + bs_y;
|
593
|
+
if (j1 > ny) j1 = ny;
|
594
|
+
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
|
595
|
+
partial_results.push_back (pres);
|
596
|
+
|
597
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
598
|
+
size_t i1 = i0 + bs_x;
|
599
|
+
if(i1 > nx) i1 = nx;
|
600
|
+
|
601
|
+
/* compute the actual dot products */
|
602
|
+
{
|
603
|
+
float one = 1, zero = 0;
|
604
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
605
|
+
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
606
|
+
y + j0 * d, &di,
|
607
|
+
x + i0 * d, &di, &zero,
|
608
|
+
ip_block, &nyi);
|
609
|
+
}
|
610
|
+
|
611
|
+
|
612
|
+
for (size_t i = i0; i < i1; i++) {
|
613
|
+
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
614
|
+
|
615
|
+
RangeQueryResult & qres = pres->new_result (i);
|
616
|
+
|
617
|
+
for (size_t j = j0; j < j1; j++) {
|
618
|
+
float ip = *ip_line++;
|
619
|
+
if (compute_l2) {
|
620
|
+
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
621
|
+
if (dis < radius) {
|
622
|
+
qres.add (dis, j);
|
623
|
+
}
|
624
|
+
} else {
|
625
|
+
if (ip > radius) {
|
626
|
+
qres.add (ip, j);
|
627
|
+
}
|
628
|
+
}
|
629
|
+
}
|
630
|
+
}
|
631
|
+
}
|
632
|
+
InterruptCallback::check ();
|
633
|
+
}
|
634
|
+
|
635
|
+
RangeSearchPartialResult::merge (partial_results);
|
636
|
+
}
|
637
|
+
|
638
|
+
|
639
|
+
template <bool compute_l2>
|
640
|
+
static void range_search_sse (const float * x,
|
641
|
+
const float * y,
|
642
|
+
size_t d, size_t nx, size_t ny,
|
643
|
+
float radius,
|
644
|
+
RangeSearchResult *res)
|
645
|
+
{
|
646
|
+
FAISS_THROW_IF_NOT (d % 4 == 0);
|
647
|
+
|
648
|
+
#pragma omp parallel
|
649
|
+
{
|
650
|
+
RangeSearchPartialResult pres (res);
|
651
|
+
|
652
|
+
#pragma omp for
|
653
|
+
for (size_t i = 0; i < nx; i++) {
|
654
|
+
const float * x_ = x + i * d;
|
655
|
+
const float * y_ = y;
|
656
|
+
size_t j;
|
657
|
+
|
658
|
+
RangeQueryResult & qres = pres.new_result (i);
|
659
|
+
|
660
|
+
for (j = 0; j < ny; j++) {
|
661
|
+
if (compute_l2) {
|
662
|
+
float disij = fvec_L2sqr (x_, y_, d);
|
663
|
+
if (disij < radius) {
|
664
|
+
qres.add (disij, j);
|
665
|
+
}
|
666
|
+
} else {
|
667
|
+
float ip = fvec_inner_product (x_, y_, d);
|
668
|
+
if (ip > radius) {
|
669
|
+
qres.add (ip, j);
|
670
|
+
}
|
671
|
+
}
|
672
|
+
y_ += d;
|
673
|
+
}
|
674
|
+
|
675
|
+
}
|
676
|
+
pres.finalize ();
|
677
|
+
}
|
678
|
+
|
679
|
+
// check just at the end because the use case is typically just
|
680
|
+
// when the nb of queries is low.
|
681
|
+
InterruptCallback::check();
|
682
|
+
}
|
683
|
+
|
684
|
+
|
685
|
+
|
686
|
+
|
687
|
+
|
688
|
+
void range_search_L2sqr (
|
689
|
+
const float * x,
|
690
|
+
const float * y,
|
691
|
+
size_t d, size_t nx, size_t ny,
|
692
|
+
float radius,
|
693
|
+
RangeSearchResult *res)
|
694
|
+
{
|
695
|
+
|
696
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
697
|
+
range_search_sse<true> (x, y, d, nx, ny, radius, res);
|
698
|
+
} else {
|
699
|
+
range_search_blas<true> (x, y, d, nx, ny, radius, res);
|
700
|
+
}
|
701
|
+
}
|
702
|
+
|
703
|
+
void range_search_inner_product (
|
704
|
+
const float * x,
|
705
|
+
const float * y,
|
706
|
+
size_t d, size_t nx, size_t ny,
|
707
|
+
float radius,
|
708
|
+
RangeSearchResult *res)
|
709
|
+
{
|
710
|
+
|
711
|
+
if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
|
712
|
+
range_search_sse<false> (x, y, d, nx, ny, radius, res);
|
713
|
+
} else {
|
714
|
+
range_search_blas<false> (x, y, d, nx, ny, radius, res);
|
715
|
+
}
|
716
|
+
}
|
717
|
+
|
718
|
+
|
719
|
+
void pairwise_L2sqr (int64_t d,
|
720
|
+
int64_t nq, const float *xq,
|
721
|
+
int64_t nb, const float *xb,
|
722
|
+
float *dis,
|
723
|
+
int64_t ldq, int64_t ldb, int64_t ldd)
|
724
|
+
{
|
725
|
+
if (nq == 0 || nb == 0) return;
|
726
|
+
if (ldq == -1) ldq = d;
|
727
|
+
if (ldb == -1) ldb = d;
|
728
|
+
if (ldd == -1) ldd = nb;
|
729
|
+
|
730
|
+
// store in beginning of distance matrix to avoid malloc
|
731
|
+
float *b_norms = dis;
|
732
|
+
|
733
|
+
#pragma omp parallel for
|
734
|
+
for (int64_t i = 0; i < nb; i++)
|
735
|
+
b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
|
736
|
+
|
737
|
+
#pragma omp parallel for
|
738
|
+
for (int64_t i = 1; i < nq; i++) {
|
739
|
+
float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
|
740
|
+
for (int64_t j = 0; j < nb; j++)
|
741
|
+
dis[i * ldd + j] = q_norm + b_norms [j];
|
742
|
+
}
|
743
|
+
|
744
|
+
{
|
745
|
+
float q_norm = fvec_norm_L2sqr (xq, d);
|
746
|
+
for (int64_t j = 0; j < nb; j++)
|
747
|
+
dis[j] += q_norm;
|
748
|
+
}
|
749
|
+
|
750
|
+
{
|
751
|
+
FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
|
752
|
+
float one = 1.0, minus_2 = -2.0;
|
753
|
+
|
754
|
+
sgemm_ ("Transposed", "Not transposed",
|
755
|
+
&nbi, &nqi, &di,
|
756
|
+
&minus_2,
|
757
|
+
xb, &ldbi,
|
758
|
+
xq, &ldqi,
|
759
|
+
&one, dis, &lddi);
|
760
|
+
}
|
761
|
+
|
762
|
+
}
|
763
|
+
|
764
|
+
|
765
|
+
} // namespace faiss
|