faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -5,9 +5,8 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
- #include <faiss/utils/WorkerThread.h>
10
8
  #include <faiss/impl/FaissAssert.h>
9
+ #include <faiss/utils/WorkerThread.h>
11
10
  #include <exception>
12
11
 
13
12
  namespace faiss {
@@ -15,112 +14,104 @@ namespace faiss {
15
14
  namespace {
16
15
 
17
16
  // Captures any exceptions thrown by the lambda and returns them via the promise
18
- void runCallback(std::function<void()>& fn,
19
- std::promise<bool>& promise) {
20
- try {
21
- fn();
22
- promise.set_value(true);
23
- } catch (...) {
24
- promise.set_exception(std::current_exception());
25
- }
17
+ void runCallback(std::function<void()>& fn, std::promise<bool>& promise) {
18
+ try {
19
+ fn();
20
+ promise.set_value(true);
21
+ } catch (...) {
22
+ promise.set_exception(std::current_exception());
23
+ }
26
24
  }
27
25
 
28
26
  } // namespace
29
27
 
30
- WorkerThread::WorkerThread() :
31
- wantStop_(false) {
32
- startThread();
28
+ WorkerThread::WorkerThread() : wantStop_(false) {
29
+ startThread();
33
30
 
34
- // Make sure that the thread has started before continuing
35
- add([](){}).get();
31
+ // Make sure that the thread has started before continuing
32
+ add([]() {}).get();
36
33
  }
37
34
 
38
35
  WorkerThread::~WorkerThread() {
39
- stop();
40
- waitForThreadExit();
36
+ stop();
37
+ waitForThreadExit();
41
38
  }
42
39
 
43
- void
44
- WorkerThread::startThread() {
45
- thread_ = std::thread([this](){ threadMain(); });
40
+ void WorkerThread::startThread() {
41
+ thread_ = std::thread([this]() { threadMain(); });
46
42
  }
47
43
 
48
- void
49
- WorkerThread::stop() {
50
- std::lock_guard<std::mutex> guard(mutex_);
44
+ void WorkerThread::stop() {
45
+ std::lock_guard<std::mutex> guard(mutex_);
51
46
 
52
- wantStop_ = true;
53
- monitor_.notify_one();
47
+ wantStop_ = true;
48
+ monitor_.notify_one();
54
49
  }
55
50
 
56
- std::future<bool>
57
- WorkerThread::add(std::function<void()> f) {
58
- std::lock_guard<std::mutex> guard(mutex_);
51
+ std::future<bool> WorkerThread::add(std::function<void()> f) {
52
+ std::lock_guard<std::mutex> guard(mutex_);
59
53
 
60
- if (wantStop_) {
61
- // The timer thread has been stopped, or we want to stop; we can't
62
- // schedule anything else
63
- std::promise<bool> p;
64
- auto fut = p.get_future();
54
+ if (wantStop_) {
55
+ // The timer thread has been stopped, or we want to stop; we can't
56
+ // schedule anything else
57
+ std::promise<bool> p;
58
+ auto fut = p.get_future();
65
59
 
66
- // did not execute
67
- p.set_value(false);
68
- return fut;
69
- }
60
+ // did not execute
61
+ p.set_value(false);
62
+ return fut;
63
+ }
70
64
 
71
- auto pr = std::promise<bool>();
72
- auto fut = pr.get_future();
65
+ auto pr = std::promise<bool>();
66
+ auto fut = pr.get_future();
73
67
 
74
- queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
68
+ queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
75
69
 
76
- // Wake up our thread
77
- monitor_.notify_one();
78
- return fut;
70
+ // Wake up our thread
71
+ monitor_.notify_one();
72
+ return fut;
79
73
  }
80
74
 
81
- void
82
- WorkerThread::threadMain() {
83
- threadLoop();
75
+ void WorkerThread::threadMain() {
76
+ threadLoop();
84
77
 
85
- // Call all pending tasks
86
- FAISS_ASSERT(wantStop_);
78
+ // Call all pending tasks
79
+ FAISS_ASSERT(wantStop_);
87
80
 
88
- // flush all pending operations
89
- for (auto& f : queue_) {
90
- runCallback(f.first, f.second);
91
- }
81
+ // flush all pending operations
82
+ for (auto& f : queue_) {
83
+ runCallback(f.first, f.second);
84
+ }
92
85
  }
93
86
 
94
- void
95
- WorkerThread::threadLoop() {
96
- while (true) {
97
- std::pair<std::function<void()>, std::promise<bool>> data;
87
+ void WorkerThread::threadLoop() {
88
+ while (true) {
89
+ std::pair<std::function<void()>, std::promise<bool>> data;
98
90
 
99
- {
100
- std::unique_lock<std::mutex> lock(mutex_);
91
+ {
92
+ std::unique_lock<std::mutex> lock(mutex_);
101
93
 
102
- while (!wantStop_ && queue_.empty()) {
103
- monitor_.wait(lock);
104
- }
94
+ while (!wantStop_ && queue_.empty()) {
95
+ monitor_.wait(lock);
96
+ }
105
97
 
106
- if (wantStop_) {
107
- return;
108
- }
98
+ if (wantStop_) {
99
+ return;
100
+ }
109
101
 
110
- data = std::move(queue_.front());
111
- queue_.pop_front();
112
- }
102
+ data = std::move(queue_.front());
103
+ queue_.pop_front();
104
+ }
113
105
 
114
- runCallback(data.first, data.second);
115
- }
106
+ runCallback(data.first, data.second);
107
+ }
116
108
  }
117
109
 
118
- void
119
- WorkerThread::waitForThreadExit() {
120
- try {
121
- thread_.join();
122
- } catch (...) {
123
- }
110
+ void WorkerThread::waitForThreadExit() {
111
+ try {
112
+ thread_.join();
113
+ } catch (...) {
114
+ }
124
115
  }
125
116
 
126
- } // namespace
117
+ } // namespace faiss
@@ -5,57 +5,56 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <condition_variable>
12
- #include <future>
13
11
  #include <deque>
12
+ #include <future>
14
13
  #include <thread>
15
14
 
16
15
  namespace faiss {
17
16
 
18
17
  class WorkerThread {
19
- public:
20
- WorkerThread();
18
+ public:
19
+ WorkerThread();
21
20
 
22
- /// Stops and waits for the worker thread to exit, flushing all
23
- /// pending lambdas
24
- ~WorkerThread();
21
+ /// Stops and waits for the worker thread to exit, flushing all
22
+ /// pending lambdas
23
+ ~WorkerThread();
25
24
 
26
- /// Request that the worker thread stop itself
27
- void stop();
25
+ /// Request that the worker thread stop itself
26
+ void stop();
28
27
 
29
- /// Blocking waits in the current thread for the worker thread to
30
- /// stop
31
- void waitForThreadExit();
28
+ /// Blocking waits in the current thread for the worker thread to
29
+ /// stop
30
+ void waitForThreadExit();
32
31
 
33
- /// Adds a lambda to run on the worker thread; returns a future that
34
- /// can be used to block on its completion.
35
- /// Future status is `true` if the lambda was run in the worker
36
- /// thread; `false` if it was not run, because the worker thread is
37
- /// exiting or has exited.
38
- std::future<bool> add(std::function<void()> f);
32
+ /// Adds a lambda to run on the worker thread; returns a future that
33
+ /// can be used to block on its completion.
34
+ /// Future status is `true` if the lambda was run in the worker
35
+ /// thread; `false` if it was not run, because the worker thread is
36
+ /// exiting or has exited.
37
+ std::future<bool> add(std::function<void()> f);
39
38
 
40
- private:
41
- void startThread();
42
- void threadMain();
43
- void threadLoop();
39
+ private:
40
+ void startThread();
41
+ void threadMain();
42
+ void threadLoop();
44
43
 
45
- /// Thread that all queued lambdas are run on
46
- std::thread thread_;
44
+ /// Thread that all queued lambdas are run on
45
+ std::thread thread_;
47
46
 
48
- /// Mutex for the queue and exit status
49
- std::mutex mutex_;
47
+ /// Mutex for the queue and exit status
48
+ std::mutex mutex_;
50
49
 
51
- /// Monitor for the exit status and the queue
52
- std::condition_variable monitor_;
50
+ /// Monitor for the exit status and the queue
51
+ std::condition_variable monitor_;
53
52
 
54
- /// Whether or not we want the thread to exit
55
- bool wantStop_;
53
+ /// Whether or not we want the thread to exit
54
+ bool wantStop_;
56
55
 
57
- /// Queue of pending lambdas to call
58
- std::deque<std::pair<std::function<void()>, std::promise<bool>>> queue_;
56
+ /// Queue of pending lambdas to call
57
+ std::deque<std::pair<std::function<void()>, std::promise<bool>>> queue_;
59
58
  };
60
59
 
61
- } // namespace
60
+ } // namespace faiss
@@ -10,10 +10,10 @@
10
10
  #include <faiss/utils/distances.h>
11
11
 
12
12
  #include <algorithm>
13
- #include <cstdio>
14
13
  #include <cassert>
15
- #include <cstring>
16
14
  #include <cmath>
15
+ #include <cstdio>
16
+ #include <cstring>
17
17
 
18
18
  #include <omp.h>
19
19
 
@@ -21,186 +21,151 @@
21
21
  #include <faiss/impl/FaissAssert.h>
22
22
  #include <faiss/impl/ResultHandler.h>
23
23
 
24
-
25
-
26
24
  #ifndef FINTEGER
27
25
  #define FINTEGER long
28
26
  #endif
29
27
 
30
-
31
28
  extern "C" {
32
29
 
33
30
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
34
31
 
35
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
36
- n, FINTEGER *k, const float *alpha, const float *a,
37
- FINTEGER *lda, const float *b, FINTEGER *
38
- ldb, float *beta, float *c, FINTEGER *ldc);
39
-
40
-
32
+ int sgemm_(
33
+ const char* transa,
34
+ const char* transb,
35
+ FINTEGER* m,
36
+ FINTEGER* n,
37
+ FINTEGER* k,
38
+ const float* alpha,
39
+ const float* a,
40
+ FINTEGER* lda,
41
+ const float* b,
42
+ FINTEGER* ldb,
43
+ float* beta,
44
+ float* c,
45
+ FINTEGER* ldc);
41
46
  }
42
47
 
43
-
44
48
  namespace faiss {
45
49
 
46
-
47
-
48
50
  /***************************************************************************
49
51
  * Matrix/vector ops
50
52
  ***************************************************************************/
51
53
 
52
-
53
-
54
-
55
54
  /* Compute the L2 norm of a set of nx vectors */
56
- void fvec_norms_L2 (float * __restrict nr,
57
- const float * __restrict x,
58
- size_t d, size_t nx)
59
- {
60
-
55
+ void fvec_norms_L2(
56
+ float* __restrict nr,
57
+ const float* __restrict x,
58
+ size_t d,
59
+ size_t nx) {
61
60
  #pragma omp parallel for
62
61
  for (int64_t i = 0; i < nx; i++) {
63
- nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
62
+ nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
64
63
  }
65
64
  }
66
65
 
67
- void fvec_norms_L2sqr (float * __restrict nr,
68
- const float * __restrict x,
69
- size_t d, size_t nx)
70
- {
66
+ void fvec_norms_L2sqr(
67
+ float* __restrict nr,
68
+ const float* __restrict x,
69
+ size_t d,
70
+ size_t nx) {
71
71
  #pragma omp parallel for
72
72
  for (int64_t i = 0; i < nx; i++)
73
- nr[i] = fvec_norm_L2sqr (x + i * d, d);
73
+ nr[i] = fvec_norm_L2sqr(x + i * d, d);
74
74
  }
75
75
 
76
-
77
-
78
- void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
79
- {
76
+ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
80
77
  #pragma omp parallel for
81
78
  for (int64_t i = 0; i < nx; i++) {
82
- float * __restrict xi = x + i * d;
79
+ float* __restrict xi = x + i * d;
83
80
 
84
- float nr = fvec_norm_L2sqr (xi, d);
81
+ float nr = fvec_norm_L2sqr(xi, d);
85
82
 
86
83
  if (nr > 0) {
87
84
  size_t j;
88
- const float inv_nr = 1.0 / sqrtf (nr);
85
+ const float inv_nr = 1.0 / sqrtf(nr);
89
86
  for (j = 0; j < d; j++)
90
87
  xi[j] *= inv_nr;
91
88
  }
92
89
  }
93
90
  }
94
91
 
95
-
96
-
97
-
98
-
99
-
100
-
101
-
102
-
103
-
104
-
105
-
106
92
  /***************************************************************************
107
93
  * KNN functions
108
94
  ***************************************************************************/
109
95
 
110
96
  namespace {
111
97
 
112
-
113
-
114
98
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
115
- template<class ResultHandler>
116
- void exhaustive_inner_product_seq (
117
- const float * x,
118
- const float * y,
119
- size_t d, size_t nx, size_t ny,
120
- ResultHandler &res)
121
- {
122
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
123
-
124
- check_period *= omp_get_max_threads();
125
-
99
+ template <class ResultHandler>
100
+ void exhaustive_inner_product_seq(
101
+ const float* x,
102
+ const float* y,
103
+ size_t d,
104
+ size_t nx,
105
+ size_t ny,
106
+ ResultHandler& res) {
126
107
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
127
108
 
128
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
129
- size_t i1 = std::min(i0 + check_period, nx);
130
-
131
109
  #pragma omp parallel
132
- {
133
- SingleResultHandler resi(res);
110
+ {
111
+ SingleResultHandler resi(res);
134
112
  #pragma omp for
135
- for (int64_t i = i0; i < i1; i++) {
136
- const float * x_i = x + i * d;
137
- const float * y_j = y;
113
+ for (int64_t i = 0; i < nx; i++) {
114
+ const float* x_i = x + i * d;
115
+ const float* y_j = y;
138
116
 
139
- resi.begin(i);
117
+ resi.begin(i);
140
118
 
141
- for (size_t j = 0; j < ny; j++) {
142
- float ip = fvec_inner_product (x_i, y_j, d);
143
- resi.add_result(ip, j);
144
- y_j += d;
145
- }
146
- resi.end();
119
+ for (size_t j = 0; j < ny; j++) {
120
+ float ip = fvec_inner_product(x_i, y_j, d);
121
+ resi.add_result(ip, j);
122
+ y_j += d;
147
123
  }
124
+ resi.end();
148
125
  }
149
- InterruptCallback::check ();
150
126
  }
151
-
152
127
  }
153
128
 
154
- template<class ResultHandler>
155
- void exhaustive_L2sqr_seq (
156
- const float * x,
157
- const float * y,
158
- size_t d, size_t nx, size_t ny,
159
- ResultHandler & res)
160
- {
161
-
162
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
163
- check_period *= omp_get_max_threads();
129
+ template <class ResultHandler>
130
+ void exhaustive_L2sqr_seq(
131
+ const float* x,
132
+ const float* y,
133
+ size_t d,
134
+ size_t nx,
135
+ size_t ny,
136
+ ResultHandler& res) {
164
137
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
165
138
 
166
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
167
- size_t i1 = std::min(i0 + check_period, nx);
168
-
169
139
  #pragma omp parallel
170
- {
171
- SingleResultHandler resi(res);
140
+ {
141
+ SingleResultHandler resi(res);
172
142
  #pragma omp for
173
- for (int64_t i = i0; i < i1; i++) {
174
- const float * x_i = x + i * d;
175
- const float * y_j = y;
176
- resi.begin(i);
177
- for (size_t j = 0; j < ny; j++) {
178
- float disij = fvec_L2sqr (x_i, y_j, d);
179
- resi.add_result(disij, j);
180
- y_j += d;
181
- }
182
- resi.end();
143
+ for (int64_t i = 0; i < nx; i++) {
144
+ const float* x_i = x + i * d;
145
+ const float* y_j = y;
146
+ resi.begin(i);
147
+ for (size_t j = 0; j < ny; j++) {
148
+ float disij = fvec_L2sqr(x_i, y_j, d);
149
+ resi.add_result(disij, j);
150
+ y_j += d;
183
151
  }
152
+ resi.end();
184
153
  }
185
- InterruptCallback::check ();
186
154
  }
187
-
188
- };
189
-
190
-
191
-
192
-
155
+ }
193
156
 
194
157
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
195
- template<class ResultHandler>
196
- void exhaustive_inner_product_blas (
197
- const float * x,
198
- const float * y,
199
- size_t d, size_t nx, size_t ny,
200
- ResultHandler & res)
201
- {
158
+ template <class ResultHandler>
159
+ void exhaustive_inner_product_blas(
160
+ const float* x,
161
+ const float* y,
162
+ size_t d,
163
+ size_t nx,
164
+ size_t ny,
165
+ ResultHandler& res) {
202
166
  // BLAS does not like empty matrices
203
- if (nx == 0 || ny == 0) return;
167
+ if (nx == 0 || ny == 0)
168
+ return;
204
169
 
205
170
  /* block sizes */
206
171
  const size_t bs_x = distance_compute_blas_query_bs;
@@ -209,86 +174,105 @@ void exhaustive_inner_product_blas (
209
174
 
210
175
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
211
176
  size_t i1 = i0 + bs_x;
212
- if(i1 > nx) i1 = nx;
177
+ if (i1 > nx)
178
+ i1 = nx;
213
179
 
214
180
  res.begin_multiple(i0, i1);
215
181
 
216
182
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
217
183
  size_t j1 = j0 + bs_y;
218
- if (j1 > ny) j1 = ny;
184
+ if (j1 > ny)
185
+ j1 = ny;
219
186
  /* compute the actual dot products */
220
187
  {
221
188
  float one = 1, zero = 0;
222
189
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
223
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
224
- y + j0 * d, &di,
225
- x + i0 * d, &di, &zero,
226
- ip_block.get(), &nyi);
190
+ sgemm_("Transpose",
191
+ "Not transpose",
192
+ &nyi,
193
+ &nxi,
194
+ &di,
195
+ &one,
196
+ y + j0 * d,
197
+ &di,
198
+ x + i0 * d,
199
+ &di,
200
+ &zero,
201
+ ip_block.get(),
202
+ &nyi);
227
203
  }
228
204
 
229
205
  res.add_results(j0, j1, ip_block.get());
230
-
231
206
  }
232
207
  res.end_multiple();
233
- InterruptCallback::check ();
234
-
208
+ InterruptCallback::check();
235
209
  }
236
210
  }
237
211
 
238
-
239
-
240
-
241
212
  // distance correction is an operator that can be applied to transform
242
213
  // the distances
243
- template<class ResultHandler>
244
- void exhaustive_L2sqr_blas (
245
- const float * x,
246
- const float * y,
247
- size_t d, size_t nx, size_t ny,
248
- ResultHandler & res,
249
- const float *y_norms = nullptr)
250
- {
214
+ template <class ResultHandler>
215
+ void exhaustive_L2sqr_blas(
216
+ const float* x,
217
+ const float* y,
218
+ size_t d,
219
+ size_t nx,
220
+ size_t ny,
221
+ ResultHandler& res,
222
+ const float* y_norms = nullptr) {
251
223
  // BLAS does not like empty matrices
252
- if (nx == 0 || ny == 0) return;
224
+ if (nx == 0 || ny == 0)
225
+ return;
253
226
 
254
227
  /* block sizes */
255
228
  const size_t bs_x = distance_compute_blas_query_bs;
256
229
  const size_t bs_y = distance_compute_blas_database_bs;
257
230
  // const size_t bs_x = 16, bs_y = 16;
258
- std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
259
- std::unique_ptr<float []> x_norms(new float[nx]);
260
- std::unique_ptr<float []> del2;
231
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
232
+ std::unique_ptr<float[]> x_norms(new float[nx]);
233
+ std::unique_ptr<float[]> del2;
261
234
 
262
- fvec_norms_L2sqr (x_norms.get(), x, d, nx);
235
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
263
236
 
264
237
  if (!y_norms) {
265
- float *y_norms2 = new float[ny];
238
+ float* y_norms2 = new float[ny];
266
239
  del2.reset(y_norms2);
267
- fvec_norms_L2sqr (y_norms2, y, d, ny);
240
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
268
241
  y_norms = y_norms2;
269
242
  }
270
243
 
271
244
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
272
245
  size_t i1 = i0 + bs_x;
273
- if(i1 > nx) i1 = nx;
246
+ if (i1 > nx)
247
+ i1 = nx;
274
248
 
275
249
  res.begin_multiple(i0, i1);
276
250
 
277
251
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
278
252
  size_t j1 = j0 + bs_y;
279
- if (j1 > ny) j1 = ny;
253
+ if (j1 > ny)
254
+ j1 = ny;
280
255
  /* compute the actual dot products */
281
256
  {
282
257
  float one = 1, zero = 0;
283
258
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
284
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
285
- y + j0 * d, &di,
286
- x + i0 * d, &di, &zero,
287
- ip_block.get(), &nyi);
259
+ sgemm_("Transpose",
260
+ "Not transpose",
261
+ &nyi,
262
+ &nxi,
263
+ &di,
264
+ &one,
265
+ y + j0 * d,
266
+ &di,
267
+ x + i0 * d,
268
+ &di,
269
+ &zero,
270
+ ip_block.get(),
271
+ &nyi);
288
272
  }
289
-
273
+ #pragma omp parallel for
290
274
  for (int64_t i = i0; i < i1; i++) {
291
- float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
275
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
292
276
 
293
277
  for (size_t j = j0; j < j1; j++) {
294
278
  float ip = *ip_line;
@@ -296,7 +280,8 @@ void exhaustive_L2sqr_blas (
296
280
 
297
281
  // negative values can occur for identical vectors
298
282
  // due to roundoff errors
299
- if (dis < 0) dis = 0;
283
+ if (dis < 0)
284
+ dis = 0;
300
285
 
301
286
  *ip_line = dis;
302
287
  ip_line++;
@@ -305,18 +290,12 @@ void exhaustive_L2sqr_blas (
305
290
  res.add_results(j0, j1, ip_block.get());
306
291
  }
307
292
  res.end_multiple();
308
- InterruptCallback::check ();
293
+ InterruptCallback::check();
309
294
  }
310
295
  }
311
296
 
312
-
313
-
314
297
  } // anonymous namespace
315
298
 
316
-
317
-
318
-
319
-
320
299
  /*******************************************************
321
300
  * KNN driver functions
322
301
  *******************************************************/
@@ -326,268 +305,275 @@ int distance_compute_blas_query_bs = 4096;
326
305
  int distance_compute_blas_database_bs = 1024;
327
306
  int distance_compute_min_k_reservoir = 100;
328
307
 
329
- void knn_inner_product (const float * x,
330
- const float * y,
331
- size_t d, size_t nx, size_t ny,
332
- float_minheap_array_t * ha)
333
- {
308
+ void knn_inner_product(
309
+ const float* x,
310
+ const float* y,
311
+ size_t d,
312
+ size_t nx,
313
+ size_t ny,
314
+ float_minheap_array_t* ha) {
334
315
  if (ha->k < distance_compute_min_k_reservoir) {
335
316
  HeapResultHandler<CMin<float, int64_t>> res(
336
- ha->nh, ha->val, ha->ids, ha->k);
317
+ ha->nh, ha->val, ha->ids, ha->k);
337
318
  if (nx < distance_compute_blas_threshold) {
338
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
319
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
339
320
  } else {
340
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
321
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
341
322
  }
342
323
  } else {
343
324
  ReservoirResultHandler<CMin<float, int64_t>> res(
344
- ha->nh, ha->val, ha->ids, ha->k);
325
+ ha->nh, ha->val, ha->ids, ha->k);
345
326
  if (nx < distance_compute_blas_threshold) {
346
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
327
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
347
328
  } else {
348
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
329
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
349
330
  }
350
331
  }
351
332
  }
352
333
 
353
-
354
-
355
-
356
- void knn_L2sqr (
357
- const float * x,
358
- const float * y,
359
- size_t d, size_t nx, size_t ny,
360
- float_maxheap_array_t * ha,
361
- const float *y_norm2
362
- ) {
363
-
334
+ void knn_L2sqr(
335
+ const float* x,
336
+ const float* y,
337
+ size_t d,
338
+ size_t nx,
339
+ size_t ny,
340
+ float_maxheap_array_t* ha,
341
+ const float* y_norm2) {
364
342
  if (ha->k < distance_compute_min_k_reservoir) {
365
343
  HeapResultHandler<CMax<float, int64_t>> res(
366
- ha->nh, ha->val, ha->ids, ha->k);
344
+ ha->nh, ha->val, ha->ids, ha->k);
367
345
 
368
346
  if (nx < distance_compute_blas_threshold) {
369
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
347
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
370
348
  } else {
371
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
349
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
372
350
  }
373
351
  } else {
374
352
  ReservoirResultHandler<CMax<float, int64_t>> res(
375
- ha->nh, ha->val, ha->ids, ha->k);
353
+ ha->nh, ha->val, ha->ids, ha->k);
376
354
  if (nx < distance_compute_blas_threshold) {
377
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
355
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
378
356
  } else {
379
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
357
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
380
358
  }
381
359
  }
382
360
  }
383
361
 
384
-
385
362
  /***************************************************************************
386
363
  * Range search
387
364
  ***************************************************************************/
388
365
 
389
-
390
-
391
-
392
- void range_search_L2sqr (
393
- const float * x,
394
- const float * y,
395
- size_t d, size_t nx, size_t ny,
366
+ void range_search_L2sqr(
367
+ const float* x,
368
+ const float* y,
369
+ size_t d,
370
+ size_t nx,
371
+ size_t ny,
396
372
  float radius,
397
- RangeSearchResult *res)
398
- {
373
+ RangeSearchResult* res) {
399
374
  RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
400
375
  if (nx < distance_compute_blas_threshold) {
401
- exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
376
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
402
377
  } else {
403
- exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
378
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
404
379
  }
405
380
  }
406
381
 
407
- void range_search_inner_product (
408
- const float * x,
409
- const float * y,
410
- size_t d, size_t nx, size_t ny,
382
+ void range_search_inner_product(
383
+ const float* x,
384
+ const float* y,
385
+ size_t d,
386
+ size_t nx,
387
+ size_t ny,
411
388
  float radius,
412
- RangeSearchResult *res)
413
- {
414
-
389
+ RangeSearchResult* res) {
415
390
  RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
416
391
  if (nx < distance_compute_blas_threshold) {
417
- exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
392
+ exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
418
393
  } else {
419
- exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
394
+ exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
420
395
  }
421
396
  }
422
397
 
423
-
424
398
  /***************************************************************************
425
399
  * compute a subset of distances
426
400
  ***************************************************************************/
427
401
 
428
402
  /* compute the inner product between x and a subset y of ny vectors,
429
403
  whose indices are given by idy. */
430
- void fvec_inner_products_by_idx (float * __restrict ip,
431
- const float * x,
432
- const float * y,
433
- const int64_t * __restrict ids, /* for y vecs */
434
- size_t d, size_t nx, size_t ny)
435
- {
404
+ void fvec_inner_products_by_idx(
405
+ float* __restrict ip,
406
+ const float* x,
407
+ const float* y,
408
+ const int64_t* __restrict ids, /* for y vecs */
409
+ size_t d,
410
+ size_t nx,
411
+ size_t ny) {
436
412
  #pragma omp parallel for
437
413
  for (int64_t j = 0; j < nx; j++) {
438
- const int64_t * __restrict idsj = ids + j * ny;
439
- const float * xj = x + j * d;
440
- float * __restrict ipj = ip + j * ny;
414
+ const int64_t* __restrict idsj = ids + j * ny;
415
+ const float* xj = x + j * d;
416
+ float* __restrict ipj = ip + j * ny;
441
417
  for (size_t i = 0; i < ny; i++) {
442
418
  if (idsj[i] < 0)
443
419
  continue;
444
- ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
420
+ ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
445
421
  }
446
422
  }
447
423
  }
448
424
 
449
-
450
-
451
425
  /* compute the inner product between x and a subset y of ny vectors,
452
426
  whose indices are given by idy. */
453
- void fvec_L2sqr_by_idx (float * __restrict dis,
454
- const float * x,
455
- const float * y,
456
- const int64_t * __restrict ids, /* ids of y vecs */
457
- size_t d, size_t nx, size_t ny)
458
- {
427
+ void fvec_L2sqr_by_idx(
428
+ float* __restrict dis,
429
+ const float* x,
430
+ const float* y,
431
+ const int64_t* __restrict ids, /* ids of y vecs */
432
+ size_t d,
433
+ size_t nx,
434
+ size_t ny) {
459
435
  #pragma omp parallel for
460
436
  for (int64_t j = 0; j < nx; j++) {
461
- const int64_t * __restrict idsj = ids + j * ny;
462
- const float * xj = x + j * d;
463
- float * __restrict disj = dis + j * ny;
437
+ const int64_t* __restrict idsj = ids + j * ny;
438
+ const float* xj = x + j * d;
439
+ float* __restrict disj = dis + j * ny;
464
440
  for (size_t i = 0; i < ny; i++) {
465
441
  if (idsj[i] < 0)
466
442
  continue;
467
- disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
443
+ disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
468
444
  }
469
445
  }
470
446
  }
471
447
 
472
- void pairwise_indexed_L2sqr (
473
- size_t d, size_t n,
474
- const float * x, const int64_t *ix,
475
- const float * y, const int64_t *iy,
476
- float *dis)
477
- {
448
+ void pairwise_indexed_L2sqr(
449
+ size_t d,
450
+ size_t n,
451
+ const float* x,
452
+ const int64_t* ix,
453
+ const float* y,
454
+ const int64_t* iy,
455
+ float* dis) {
478
456
  #pragma omp parallel for
479
457
  for (int64_t j = 0; j < n; j++) {
480
458
  if (ix[j] >= 0 && iy[j] >= 0) {
481
- dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
459
+ dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
482
460
  }
483
461
  }
484
462
  }
485
463
 
486
- void pairwise_indexed_inner_product (
487
- size_t d, size_t n,
488
- const float * x, const int64_t *ix,
489
- const float * y, const int64_t *iy,
490
- float *dis)
491
- {
464
+ void pairwise_indexed_inner_product(
465
+ size_t d,
466
+ size_t n,
467
+ const float* x,
468
+ const int64_t* ix,
469
+ const float* y,
470
+ const int64_t* iy,
471
+ float* dis) {
492
472
  #pragma omp parallel for
493
473
  for (int64_t j = 0; j < n; j++) {
494
474
  if (ix[j] >= 0 && iy[j] >= 0) {
495
- dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
475
+ dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
496
476
  }
497
477
  }
498
478
  }
499
479
 
500
-
501
480
  /* Find the nearest neighbors for nx queries in a set of ny vectors
502
481
  indexed by ids. May be useful for re-ranking a pre-selected vector list */
503
- void knn_inner_products_by_idx (const float * x,
504
- const float * y,
505
- const int64_t * ids,
506
- size_t d, size_t nx, size_t ny,
507
- float_minheap_array_t * res)
508
- {
482
+ void knn_inner_products_by_idx(
483
+ const float* x,
484
+ const float* y,
485
+ const int64_t* ids,
486
+ size_t d,
487
+ size_t nx,
488
+ size_t ny,
489
+ float_minheap_array_t* res) {
509
490
  size_t k = res->k;
510
491
 
511
492
  #pragma omp parallel for
512
493
  for (int64_t i = 0; i < nx; i++) {
513
- const float * x_ = x + i * d;
514
- const int64_t * idsi = ids + i * ny;
494
+ const float* x_ = x + i * d;
495
+ const int64_t* idsi = ids + i * ny;
515
496
  size_t j;
516
- float * __restrict simi = res->get_val(i);
517
- int64_t * __restrict idxi = res->get_ids (i);
518
- minheap_heapify (k, simi, idxi);
497
+ float* __restrict simi = res->get_val(i);
498
+ int64_t* __restrict idxi = res->get_ids(i);
499
+ minheap_heapify(k, simi, idxi);
519
500
 
520
501
  for (j = 0; j < ny; j++) {
521
- if (idsi[j] < 0) break;
522
- float ip = fvec_inner_product (x_, y + d * idsi[j], d);
502
+ if (idsi[j] < 0)
503
+ break;
504
+ float ip = fvec_inner_product(x_, y + d * idsi[j], d);
523
505
 
524
506
  if (ip > simi[0]) {
525
- minheap_replace_top (k, simi, idxi, ip, idsi[j]);
507
+ minheap_replace_top(k, simi, idxi, ip, idsi[j]);
526
508
  }
527
509
  }
528
- minheap_reorder (k, simi, idxi);
510
+ minheap_reorder(k, simi, idxi);
529
511
  }
530
-
531
512
  }
532
513
 
533
- void knn_L2sqr_by_idx (const float * x,
534
- const float * y,
535
- const int64_t * __restrict ids,
536
- size_t d, size_t nx, size_t ny,
537
- float_maxheap_array_t * res)
538
- {
514
+ void knn_L2sqr_by_idx(
515
+ const float* x,
516
+ const float* y,
517
+ const int64_t* __restrict ids,
518
+ size_t d,
519
+ size_t nx,
520
+ size_t ny,
521
+ float_maxheap_array_t* res) {
539
522
  size_t k = res->k;
540
523
 
541
524
  #pragma omp parallel for
542
525
  for (int64_t i = 0; i < nx; i++) {
543
- const float * x_ = x + i * d;
544
- const int64_t * __restrict idsi = ids + i * ny;
545
- float * __restrict simi = res->get_val(i);
546
- int64_t * __restrict idxi = res->get_ids (i);
547
- maxheap_heapify (res->k, simi, idxi);
526
+ const float* x_ = x + i * d;
527
+ const int64_t* __restrict idsi = ids + i * ny;
528
+ float* __restrict simi = res->get_val(i);
529
+ int64_t* __restrict idxi = res->get_ids(i);
530
+ maxheap_heapify(res->k, simi, idxi);
548
531
  for (size_t j = 0; j < ny; j++) {
549
- float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
532
+ float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
550
533
 
551
534
  if (disij < simi[0]) {
552
- maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
535
+ maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
553
536
  }
554
537
  }
555
- maxheap_reorder (res->k, simi, idxi);
538
+ maxheap_reorder(res->k, simi, idxi);
556
539
  }
557
-
558
540
  }
559
541
 
560
-
561
-
562
-
563
-
564
- void pairwise_L2sqr (int64_t d,
565
- int64_t nq, const float *xq,
566
- int64_t nb, const float *xb,
567
- float *dis,
568
- int64_t ldq, int64_t ldb, int64_t ldd)
569
- {
570
- if (nq == 0 || nb == 0) return;
571
- if (ldq == -1) ldq = d;
572
- if (ldb == -1) ldb = d;
573
- if (ldd == -1) ldd = nb;
542
+ void pairwise_L2sqr(
543
+ int64_t d,
544
+ int64_t nq,
545
+ const float* xq,
546
+ int64_t nb,
547
+ const float* xb,
548
+ float* dis,
549
+ int64_t ldq,
550
+ int64_t ldb,
551
+ int64_t ldd) {
552
+ if (nq == 0 || nb == 0)
553
+ return;
554
+ if (ldq == -1)
555
+ ldq = d;
556
+ if (ldb == -1)
557
+ ldb = d;
558
+ if (ldd == -1)
559
+ ldd = nb;
574
560
 
575
561
  // store in beginning of distance matrix to avoid malloc
576
- float *b_norms = dis;
562
+ float* b_norms = dis;
577
563
 
578
564
  #pragma omp parallel for
579
565
  for (int64_t i = 0; i < nb; i++)
580
- b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
566
+ b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
581
567
 
582
568
  #pragma omp parallel for
583
569
  for (int64_t i = 1; i < nq; i++) {
584
- float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
570
+ float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
585
571
  for (int64_t j = 0; j < nb; j++)
586
- dis[i * ldd + j] = q_norm + b_norms [j];
572
+ dis[i * ldd + j] = q_norm + b_norms[j];
587
573
  }
588
574
 
589
575
  {
590
- float q_norm = fvec_norm_L2sqr (xq, d);
576
+ float q_norm = fvec_norm_L2sqr(xq, d);
591
577
  for (int64_t j = 0; j < nb; j++)
592
578
  dis[j] += q_norm;
593
579
  }
@@ -596,22 +582,28 @@ void pairwise_L2sqr (int64_t d,
596
582
  FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
597
583
  float one = 1.0, minus_2 = -2.0;
598
584
 
599
- sgemm_ ("Transposed", "Not transposed",
600
- &nbi, &nqi, &di,
601
- &minus_2,
602
- xb, &ldbi,
603
- xq, &ldqi,
604
- &one, dis, &lddi);
585
+ sgemm_("Transposed",
586
+ "Not transposed",
587
+ &nbi,
588
+ &nqi,
589
+ &di,
590
+ &minus_2,
591
+ xb,
592
+ &ldbi,
593
+ xq,
594
+ &ldqi,
595
+ &one,
596
+ dis,
597
+ &lddi);
605
598
  }
606
-
607
599
  }
608
600
 
609
- void inner_product_to_L2sqr(float* __restrict dis,
610
- const float* nr1,
611
- const float* nr2,
612
- size_t n1, size_t n2)
613
- {
614
-
601
+ void inner_product_to_L2sqr(
602
+ float* __restrict dis,
603
+ const float* nr1,
604
+ const float* nr2,
605
+ size_t n1,
606
+ size_t n2) {
615
607
  #pragma omp parallel for
616
608
  for (int64_t j = 0; j < n1; j++) {
617
609
  float* disj = dis + j * n2;
@@ -620,5 +612,4 @@ void inner_product_to_L2sqr(float* __restrict dis,
620
612
  }
621
613
  }
622
614
 
623
-
624
615
  } // namespace faiss