faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -5,9 +5,8 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
- #include <faiss/utils/WorkerThread.h>
10
8
  #include <faiss/impl/FaissAssert.h>
9
+ #include <faiss/utils/WorkerThread.h>
11
10
  #include <exception>
12
11
 
13
12
  namespace faiss {
@@ -15,112 +14,104 @@ namespace faiss {
15
14
  namespace {
16
15
 
17
16
  // Captures any exceptions thrown by the lambda and returns them via the promise
18
- void runCallback(std::function<void()>& fn,
19
- std::promise<bool>& promise) {
20
- try {
21
- fn();
22
- promise.set_value(true);
23
- } catch (...) {
24
- promise.set_exception(std::current_exception());
25
- }
17
+ void runCallback(std::function<void()>& fn, std::promise<bool>& promise) {
18
+ try {
19
+ fn();
20
+ promise.set_value(true);
21
+ } catch (...) {
22
+ promise.set_exception(std::current_exception());
23
+ }
26
24
  }
27
25
 
28
26
  } // namespace
29
27
 
30
- WorkerThread::WorkerThread() :
31
- wantStop_(false) {
32
- startThread();
28
+ WorkerThread::WorkerThread() : wantStop_(false) {
29
+ startThread();
33
30
 
34
- // Make sure that the thread has started before continuing
35
- add([](){}).get();
31
+ // Make sure that the thread has started before continuing
32
+ add([]() {}).get();
36
33
  }
37
34
 
38
35
  WorkerThread::~WorkerThread() {
39
- stop();
40
- waitForThreadExit();
36
+ stop();
37
+ waitForThreadExit();
41
38
  }
42
39
 
43
- void
44
- WorkerThread::startThread() {
45
- thread_ = std::thread([this](){ threadMain(); });
40
+ void WorkerThread::startThread() {
41
+ thread_ = std::thread([this]() { threadMain(); });
46
42
  }
47
43
 
48
- void
49
- WorkerThread::stop() {
50
- std::lock_guard<std::mutex> guard(mutex_);
44
+ void WorkerThread::stop() {
45
+ std::lock_guard<std::mutex> guard(mutex_);
51
46
 
52
- wantStop_ = true;
53
- monitor_.notify_one();
47
+ wantStop_ = true;
48
+ monitor_.notify_one();
54
49
  }
55
50
 
56
- std::future<bool>
57
- WorkerThread::add(std::function<void()> f) {
58
- std::lock_guard<std::mutex> guard(mutex_);
51
+ std::future<bool> WorkerThread::add(std::function<void()> f) {
52
+ std::lock_guard<std::mutex> guard(mutex_);
59
53
 
60
- if (wantStop_) {
61
- // The timer thread has been stopped, or we want to stop; we can't
62
- // schedule anything else
63
- std::promise<bool> p;
64
- auto fut = p.get_future();
54
+ if (wantStop_) {
55
+ // The timer thread has been stopped, or we want to stop; we can't
56
+ // schedule anything else
57
+ std::promise<bool> p;
58
+ auto fut = p.get_future();
65
59
 
66
- // did not execute
67
- p.set_value(false);
68
- return fut;
69
- }
60
+ // did not execute
61
+ p.set_value(false);
62
+ return fut;
63
+ }
70
64
 
71
- auto pr = std::promise<bool>();
72
- auto fut = pr.get_future();
65
+ auto pr = std::promise<bool>();
66
+ auto fut = pr.get_future();
73
67
 
74
- queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
68
+ queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
75
69
 
76
- // Wake up our thread
77
- monitor_.notify_one();
78
- return fut;
70
+ // Wake up our thread
71
+ monitor_.notify_one();
72
+ return fut;
79
73
  }
80
74
 
81
- void
82
- WorkerThread::threadMain() {
83
- threadLoop();
75
+ void WorkerThread::threadMain() {
76
+ threadLoop();
84
77
 
85
- // Call all pending tasks
86
- FAISS_ASSERT(wantStop_);
78
+ // Call all pending tasks
79
+ FAISS_ASSERT(wantStop_);
87
80
 
88
- // flush all pending operations
89
- for (auto& f : queue_) {
90
- runCallback(f.first, f.second);
91
- }
81
+ // flush all pending operations
82
+ for (auto& f : queue_) {
83
+ runCallback(f.first, f.second);
84
+ }
92
85
  }
93
86
 
94
- void
95
- WorkerThread::threadLoop() {
96
- while (true) {
97
- std::pair<std::function<void()>, std::promise<bool>> data;
87
+ void WorkerThread::threadLoop() {
88
+ while (true) {
89
+ std::pair<std::function<void()>, std::promise<bool>> data;
98
90
 
99
- {
100
- std::unique_lock<std::mutex> lock(mutex_);
91
+ {
92
+ std::unique_lock<std::mutex> lock(mutex_);
101
93
 
102
- while (!wantStop_ && queue_.empty()) {
103
- monitor_.wait(lock);
104
- }
94
+ while (!wantStop_ && queue_.empty()) {
95
+ monitor_.wait(lock);
96
+ }
105
97
 
106
- if (wantStop_) {
107
- return;
108
- }
98
+ if (wantStop_) {
99
+ return;
100
+ }
109
101
 
110
- data = std::move(queue_.front());
111
- queue_.pop_front();
112
- }
102
+ data = std::move(queue_.front());
103
+ queue_.pop_front();
104
+ }
113
105
 
114
- runCallback(data.first, data.second);
115
- }
106
+ runCallback(data.first, data.second);
107
+ }
116
108
  }
117
109
 
118
- void
119
- WorkerThread::waitForThreadExit() {
120
- try {
121
- thread_.join();
122
- } catch (...) {
123
- }
110
+ void WorkerThread::waitForThreadExit() {
111
+ try {
112
+ thread_.join();
113
+ } catch (...) {
114
+ }
124
115
  }
125
116
 
126
- } // namespace
117
+ } // namespace faiss
@@ -5,57 +5,56 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <condition_variable>
12
- #include <future>
13
11
  #include <deque>
12
+ #include <future>
14
13
  #include <thread>
15
14
 
16
15
  namespace faiss {
17
16
 
18
17
  class WorkerThread {
19
- public:
20
- WorkerThread();
18
+ public:
19
+ WorkerThread();
21
20
 
22
- /// Stops and waits for the worker thread to exit, flushing all
23
- /// pending lambdas
24
- ~WorkerThread();
21
+ /// Stops and waits for the worker thread to exit, flushing all
22
+ /// pending lambdas
23
+ ~WorkerThread();
25
24
 
26
- /// Request that the worker thread stop itself
27
- void stop();
25
+ /// Request that the worker thread stop itself
26
+ void stop();
28
27
 
29
- /// Blocking waits in the current thread for the worker thread to
30
- /// stop
31
- void waitForThreadExit();
28
+ /// Blocking waits in the current thread for the worker thread to
29
+ /// stop
30
+ void waitForThreadExit();
32
31
 
33
- /// Adds a lambda to run on the worker thread; returns a future that
34
- /// can be used to block on its completion.
35
- /// Future status is `true` if the lambda was run in the worker
36
- /// thread; `false` if it was not run, because the worker thread is
37
- /// exiting or has exited.
38
- std::future<bool> add(std::function<void()> f);
32
+ /// Adds a lambda to run on the worker thread; returns a future that
33
+ /// can be used to block on its completion.
34
+ /// Future status is `true` if the lambda was run in the worker
35
+ /// thread; `false` if it was not run, because the worker thread is
36
+ /// exiting or has exited.
37
+ std::future<bool> add(std::function<void()> f);
39
38
 
40
- private:
41
- void startThread();
42
- void threadMain();
43
- void threadLoop();
39
+ private:
40
+ void startThread();
41
+ void threadMain();
42
+ void threadLoop();
44
43
 
45
- /// Thread that all queued lambdas are run on
46
- std::thread thread_;
44
+ /// Thread that all queued lambdas are run on
45
+ std::thread thread_;
47
46
 
48
- /// Mutex for the queue and exit status
49
- std::mutex mutex_;
47
+ /// Mutex for the queue and exit status
48
+ std::mutex mutex_;
50
49
 
51
- /// Monitor for the exit status and the queue
52
- std::condition_variable monitor_;
50
+ /// Monitor for the exit status and the queue
51
+ std::condition_variable monitor_;
53
52
 
54
- /// Whether or not we want the thread to exit
55
- bool wantStop_;
53
+ /// Whether or not we want the thread to exit
54
+ bool wantStop_;
56
55
 
57
- /// Queue of pending lambdas to call
58
- std::deque<std::pair<std::function<void()>, std::promise<bool>>> queue_;
56
+ /// Queue of pending lambdas to call
57
+ std::deque<std::pair<std::function<void()>, std::promise<bool>>> queue_;
59
58
  };
60
59
 
61
- } // namespace
60
+ } // namespace faiss
@@ -10,10 +10,10 @@
10
10
  #include <faiss/utils/distances.h>
11
11
 
12
12
  #include <algorithm>
13
- #include <cstdio>
14
13
  #include <cassert>
15
- #include <cstring>
16
14
  #include <cmath>
15
+ #include <cstdio>
16
+ #include <cstring>
17
17
 
18
18
  #include <omp.h>
19
19
 
@@ -21,186 +21,151 @@
21
21
  #include <faiss/impl/FaissAssert.h>
22
22
  #include <faiss/impl/ResultHandler.h>
23
23
 
24
-
25
-
26
24
  #ifndef FINTEGER
27
25
  #define FINTEGER long
28
26
  #endif
29
27
 
30
-
31
28
  extern "C" {
32
29
 
33
30
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
34
31
 
35
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
36
- n, FINTEGER *k, const float *alpha, const float *a,
37
- FINTEGER *lda, const float *b, FINTEGER *
38
- ldb, float *beta, float *c, FINTEGER *ldc);
39
-
40
-
32
+ int sgemm_(
33
+ const char* transa,
34
+ const char* transb,
35
+ FINTEGER* m,
36
+ FINTEGER* n,
37
+ FINTEGER* k,
38
+ const float* alpha,
39
+ const float* a,
40
+ FINTEGER* lda,
41
+ const float* b,
42
+ FINTEGER* ldb,
43
+ float* beta,
44
+ float* c,
45
+ FINTEGER* ldc);
41
46
  }
42
47
 
43
-
44
48
  namespace faiss {
45
49
 
46
-
47
-
48
50
  /***************************************************************************
49
51
  * Matrix/vector ops
50
52
  ***************************************************************************/
51
53
 
52
-
53
-
54
-
55
54
  /* Compute the L2 norm of a set of nx vectors */
56
- void fvec_norms_L2 (float * __restrict nr,
57
- const float * __restrict x,
58
- size_t d, size_t nx)
59
- {
60
-
55
+ void fvec_norms_L2(
56
+ float* __restrict nr,
57
+ const float* __restrict x,
58
+ size_t d,
59
+ size_t nx) {
61
60
  #pragma omp parallel for
62
61
  for (int64_t i = 0; i < nx; i++) {
63
- nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
62
+ nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
64
63
  }
65
64
  }
66
65
 
67
- void fvec_norms_L2sqr (float * __restrict nr,
68
- const float * __restrict x,
69
- size_t d, size_t nx)
70
- {
66
+ void fvec_norms_L2sqr(
67
+ float* __restrict nr,
68
+ const float* __restrict x,
69
+ size_t d,
70
+ size_t nx) {
71
71
  #pragma omp parallel for
72
72
  for (int64_t i = 0; i < nx; i++)
73
- nr[i] = fvec_norm_L2sqr (x + i * d, d);
73
+ nr[i] = fvec_norm_L2sqr(x + i * d, d);
74
74
  }
75
75
 
76
-
77
-
78
- void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
79
- {
76
+ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
80
77
  #pragma omp parallel for
81
78
  for (int64_t i = 0; i < nx; i++) {
82
- float * __restrict xi = x + i * d;
79
+ float* __restrict xi = x + i * d;
83
80
 
84
- float nr = fvec_norm_L2sqr (xi, d);
81
+ float nr = fvec_norm_L2sqr(xi, d);
85
82
 
86
83
  if (nr > 0) {
87
84
  size_t j;
88
- const float inv_nr = 1.0 / sqrtf (nr);
85
+ const float inv_nr = 1.0 / sqrtf(nr);
89
86
  for (j = 0; j < d; j++)
90
87
  xi[j] *= inv_nr;
91
88
  }
92
89
  }
93
90
  }
94
91
 
95
-
96
-
97
-
98
-
99
-
100
-
101
-
102
-
103
-
104
-
105
-
106
92
  /***************************************************************************
107
93
  * KNN functions
108
94
  ***************************************************************************/
109
95
 
110
96
  namespace {
111
97
 
112
-
113
-
114
98
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
115
- template<class ResultHandler>
116
- void exhaustive_inner_product_seq (
117
- const float * x,
118
- const float * y,
119
- size_t d, size_t nx, size_t ny,
120
- ResultHandler &res)
121
- {
122
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
123
-
124
- check_period *= omp_get_max_threads();
125
-
99
+ template <class ResultHandler>
100
+ void exhaustive_inner_product_seq(
101
+ const float* x,
102
+ const float* y,
103
+ size_t d,
104
+ size_t nx,
105
+ size_t ny,
106
+ ResultHandler& res) {
126
107
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
127
108
 
128
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
129
- size_t i1 = std::min(i0 + check_period, nx);
130
-
131
109
  #pragma omp parallel
132
- {
133
- SingleResultHandler resi(res);
110
+ {
111
+ SingleResultHandler resi(res);
134
112
  #pragma omp for
135
- for (int64_t i = i0; i < i1; i++) {
136
- const float * x_i = x + i * d;
137
- const float * y_j = y;
113
+ for (int64_t i = 0; i < nx; i++) {
114
+ const float* x_i = x + i * d;
115
+ const float* y_j = y;
138
116
 
139
- resi.begin(i);
117
+ resi.begin(i);
140
118
 
141
- for (size_t j = 0; j < ny; j++) {
142
- float ip = fvec_inner_product (x_i, y_j, d);
143
- resi.add_result(ip, j);
144
- y_j += d;
145
- }
146
- resi.end();
119
+ for (size_t j = 0; j < ny; j++) {
120
+ float ip = fvec_inner_product(x_i, y_j, d);
121
+ resi.add_result(ip, j);
122
+ y_j += d;
147
123
  }
124
+ resi.end();
148
125
  }
149
- InterruptCallback::check ();
150
126
  }
151
-
152
127
  }
153
128
 
154
- template<class ResultHandler>
155
- void exhaustive_L2sqr_seq (
156
- const float * x,
157
- const float * y,
158
- size_t d, size_t nx, size_t ny,
159
- ResultHandler & res)
160
- {
161
-
162
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
163
- check_period *= omp_get_max_threads();
129
+ template <class ResultHandler>
130
+ void exhaustive_L2sqr_seq(
131
+ const float* x,
132
+ const float* y,
133
+ size_t d,
134
+ size_t nx,
135
+ size_t ny,
136
+ ResultHandler& res) {
164
137
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
165
138
 
166
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
167
- size_t i1 = std::min(i0 + check_period, nx);
168
-
169
139
  #pragma omp parallel
170
- {
171
- SingleResultHandler resi(res);
140
+ {
141
+ SingleResultHandler resi(res);
172
142
  #pragma omp for
173
- for (int64_t i = i0; i < i1; i++) {
174
- const float * x_i = x + i * d;
175
- const float * y_j = y;
176
- resi.begin(i);
177
- for (size_t j = 0; j < ny; j++) {
178
- float disij = fvec_L2sqr (x_i, y_j, d);
179
- resi.add_result(disij, j);
180
- y_j += d;
181
- }
182
- resi.end();
143
+ for (int64_t i = 0; i < nx; i++) {
144
+ const float* x_i = x + i * d;
145
+ const float* y_j = y;
146
+ resi.begin(i);
147
+ for (size_t j = 0; j < ny; j++) {
148
+ float disij = fvec_L2sqr(x_i, y_j, d);
149
+ resi.add_result(disij, j);
150
+ y_j += d;
183
151
  }
152
+ resi.end();
184
153
  }
185
- InterruptCallback::check ();
186
154
  }
187
-
188
- };
189
-
190
-
191
-
192
-
155
+ }
193
156
 
194
157
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
195
- template<class ResultHandler>
196
- void exhaustive_inner_product_blas (
197
- const float * x,
198
- const float * y,
199
- size_t d, size_t nx, size_t ny,
200
- ResultHandler & res)
201
- {
158
+ template <class ResultHandler>
159
+ void exhaustive_inner_product_blas(
160
+ const float* x,
161
+ const float* y,
162
+ size_t d,
163
+ size_t nx,
164
+ size_t ny,
165
+ ResultHandler& res) {
202
166
  // BLAS does not like empty matrices
203
- if (nx == 0 || ny == 0) return;
167
+ if (nx == 0 || ny == 0)
168
+ return;
204
169
 
205
170
  /* block sizes */
206
171
  const size_t bs_x = distance_compute_blas_query_bs;
@@ -209,86 +174,105 @@ void exhaustive_inner_product_blas (
209
174
 
210
175
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
211
176
  size_t i1 = i0 + bs_x;
212
- if(i1 > nx) i1 = nx;
177
+ if (i1 > nx)
178
+ i1 = nx;
213
179
 
214
180
  res.begin_multiple(i0, i1);
215
181
 
216
182
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
217
183
  size_t j1 = j0 + bs_y;
218
- if (j1 > ny) j1 = ny;
184
+ if (j1 > ny)
185
+ j1 = ny;
219
186
  /* compute the actual dot products */
220
187
  {
221
188
  float one = 1, zero = 0;
222
189
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
223
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
224
- y + j0 * d, &di,
225
- x + i0 * d, &di, &zero,
226
- ip_block.get(), &nyi);
190
+ sgemm_("Transpose",
191
+ "Not transpose",
192
+ &nyi,
193
+ &nxi,
194
+ &di,
195
+ &one,
196
+ y + j0 * d,
197
+ &di,
198
+ x + i0 * d,
199
+ &di,
200
+ &zero,
201
+ ip_block.get(),
202
+ &nyi);
227
203
  }
228
204
 
229
205
  res.add_results(j0, j1, ip_block.get());
230
-
231
206
  }
232
207
  res.end_multiple();
233
- InterruptCallback::check ();
234
-
208
+ InterruptCallback::check();
235
209
  }
236
210
  }
237
211
 
238
-
239
-
240
-
241
212
  // distance correction is an operator that can be applied to transform
242
213
  // the distances
243
- template<class ResultHandler>
244
- void exhaustive_L2sqr_blas (
245
- const float * x,
246
- const float * y,
247
- size_t d, size_t nx, size_t ny,
248
- ResultHandler & res,
249
- const float *y_norms = nullptr)
250
- {
214
+ template <class ResultHandler>
215
+ void exhaustive_L2sqr_blas(
216
+ const float* x,
217
+ const float* y,
218
+ size_t d,
219
+ size_t nx,
220
+ size_t ny,
221
+ ResultHandler& res,
222
+ const float* y_norms = nullptr) {
251
223
  // BLAS does not like empty matrices
252
- if (nx == 0 || ny == 0) return;
224
+ if (nx == 0 || ny == 0)
225
+ return;
253
226
 
254
227
  /* block sizes */
255
228
  const size_t bs_x = distance_compute_blas_query_bs;
256
229
  const size_t bs_y = distance_compute_blas_database_bs;
257
230
  // const size_t bs_x = 16, bs_y = 16;
258
- std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
259
- std::unique_ptr<float []> x_norms(new float[nx]);
260
- std::unique_ptr<float []> del2;
231
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
232
+ std::unique_ptr<float[]> x_norms(new float[nx]);
233
+ std::unique_ptr<float[]> del2;
261
234
 
262
- fvec_norms_L2sqr (x_norms.get(), x, d, nx);
235
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
263
236
 
264
237
  if (!y_norms) {
265
- float *y_norms2 = new float[ny];
238
+ float* y_norms2 = new float[ny];
266
239
  del2.reset(y_norms2);
267
- fvec_norms_L2sqr (y_norms2, y, d, ny);
240
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
268
241
  y_norms = y_norms2;
269
242
  }
270
243
 
271
244
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
272
245
  size_t i1 = i0 + bs_x;
273
- if(i1 > nx) i1 = nx;
246
+ if (i1 > nx)
247
+ i1 = nx;
274
248
 
275
249
  res.begin_multiple(i0, i1);
276
250
 
277
251
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
278
252
  size_t j1 = j0 + bs_y;
279
- if (j1 > ny) j1 = ny;
253
+ if (j1 > ny)
254
+ j1 = ny;
280
255
  /* compute the actual dot products */
281
256
  {
282
257
  float one = 1, zero = 0;
283
258
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
284
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
285
- y + j0 * d, &di,
286
- x + i0 * d, &di, &zero,
287
- ip_block.get(), &nyi);
259
+ sgemm_("Transpose",
260
+ "Not transpose",
261
+ &nyi,
262
+ &nxi,
263
+ &di,
264
+ &one,
265
+ y + j0 * d,
266
+ &di,
267
+ x + i0 * d,
268
+ &di,
269
+ &zero,
270
+ ip_block.get(),
271
+ &nyi);
288
272
  }
289
-
273
+ #pragma omp parallel for
290
274
  for (int64_t i = i0; i < i1; i++) {
291
- float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
275
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
292
276
 
293
277
  for (size_t j = j0; j < j1; j++) {
294
278
  float ip = *ip_line;
@@ -296,7 +280,8 @@ void exhaustive_L2sqr_blas (
296
280
 
297
281
  // negative values can occur for identical vectors
298
282
  // due to roundoff errors
299
- if (dis < 0) dis = 0;
283
+ if (dis < 0)
284
+ dis = 0;
300
285
 
301
286
  *ip_line = dis;
302
287
  ip_line++;
@@ -305,18 +290,12 @@ void exhaustive_L2sqr_blas (
305
290
  res.add_results(j0, j1, ip_block.get());
306
291
  }
307
292
  res.end_multiple();
308
- InterruptCallback::check ();
293
+ InterruptCallback::check();
309
294
  }
310
295
  }
311
296
 
312
-
313
-
314
297
  } // anonymous namespace
315
298
 
316
-
317
-
318
-
319
-
320
299
  /*******************************************************
321
300
  * KNN driver functions
322
301
  *******************************************************/
@@ -326,268 +305,275 @@ int distance_compute_blas_query_bs = 4096;
326
305
  int distance_compute_blas_database_bs = 1024;
327
306
  int distance_compute_min_k_reservoir = 100;
328
307
 
329
- void knn_inner_product (const float * x,
330
- const float * y,
331
- size_t d, size_t nx, size_t ny,
332
- float_minheap_array_t * ha)
333
- {
308
+ void knn_inner_product(
309
+ const float* x,
310
+ const float* y,
311
+ size_t d,
312
+ size_t nx,
313
+ size_t ny,
314
+ float_minheap_array_t* ha) {
334
315
  if (ha->k < distance_compute_min_k_reservoir) {
335
316
  HeapResultHandler<CMin<float, int64_t>> res(
336
- ha->nh, ha->val, ha->ids, ha->k);
317
+ ha->nh, ha->val, ha->ids, ha->k);
337
318
  if (nx < distance_compute_blas_threshold) {
338
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
319
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
339
320
  } else {
340
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
321
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
341
322
  }
342
323
  } else {
343
324
  ReservoirResultHandler<CMin<float, int64_t>> res(
344
- ha->nh, ha->val, ha->ids, ha->k);
325
+ ha->nh, ha->val, ha->ids, ha->k);
345
326
  if (nx < distance_compute_blas_threshold) {
346
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
327
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
347
328
  } else {
348
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
329
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
349
330
  }
350
331
  }
351
332
  }
352
333
 
353
-
354
-
355
-
356
- void knn_L2sqr (
357
- const float * x,
358
- const float * y,
359
- size_t d, size_t nx, size_t ny,
360
- float_maxheap_array_t * ha,
361
- const float *y_norm2
362
- ) {
363
-
334
+ void knn_L2sqr(
335
+ const float* x,
336
+ const float* y,
337
+ size_t d,
338
+ size_t nx,
339
+ size_t ny,
340
+ float_maxheap_array_t* ha,
341
+ const float* y_norm2) {
364
342
  if (ha->k < distance_compute_min_k_reservoir) {
365
343
  HeapResultHandler<CMax<float, int64_t>> res(
366
- ha->nh, ha->val, ha->ids, ha->k);
344
+ ha->nh, ha->val, ha->ids, ha->k);
367
345
 
368
346
  if (nx < distance_compute_blas_threshold) {
369
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
347
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
370
348
  } else {
371
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
349
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
372
350
  }
373
351
  } else {
374
352
  ReservoirResultHandler<CMax<float, int64_t>> res(
375
- ha->nh, ha->val, ha->ids, ha->k);
353
+ ha->nh, ha->val, ha->ids, ha->k);
376
354
  if (nx < distance_compute_blas_threshold) {
377
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
355
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
378
356
  } else {
379
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
357
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
380
358
  }
381
359
  }
382
360
  }
383
361
 
384
-
385
362
  /***************************************************************************
386
363
  * Range search
387
364
  ***************************************************************************/
388
365
 
389
-
390
-
391
-
392
- void range_search_L2sqr (
393
- const float * x,
394
- const float * y,
395
- size_t d, size_t nx, size_t ny,
366
+ void range_search_L2sqr(
367
+ const float* x,
368
+ const float* y,
369
+ size_t d,
370
+ size_t nx,
371
+ size_t ny,
396
372
  float radius,
397
- RangeSearchResult *res)
398
- {
373
+ RangeSearchResult* res) {
399
374
  RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
400
375
  if (nx < distance_compute_blas_threshold) {
401
- exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
376
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
402
377
  } else {
403
- exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
378
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
404
379
  }
405
380
  }
406
381
 
407
- void range_search_inner_product (
408
- const float * x,
409
- const float * y,
410
- size_t d, size_t nx, size_t ny,
382
+ void range_search_inner_product(
383
+ const float* x,
384
+ const float* y,
385
+ size_t d,
386
+ size_t nx,
387
+ size_t ny,
411
388
  float radius,
412
- RangeSearchResult *res)
413
- {
414
-
389
+ RangeSearchResult* res) {
415
390
  RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
416
391
  if (nx < distance_compute_blas_threshold) {
417
- exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
392
+ exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
418
393
  } else {
419
- exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
394
+ exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
420
395
  }
421
396
  }
422
397
 
423
-
424
398
  /***************************************************************************
425
399
  * compute a subset of distances
426
400
  ***************************************************************************/
427
401
 
428
402
  /* compute the inner product between x and a subset y of ny vectors,
429
403
  whose indices are given by idy. */
430
- void fvec_inner_products_by_idx (float * __restrict ip,
431
- const float * x,
432
- const float * y,
433
- const int64_t * __restrict ids, /* for y vecs */
434
- size_t d, size_t nx, size_t ny)
435
- {
404
+ void fvec_inner_products_by_idx(
405
+ float* __restrict ip,
406
+ const float* x,
407
+ const float* y,
408
+ const int64_t* __restrict ids, /* for y vecs */
409
+ size_t d,
410
+ size_t nx,
411
+ size_t ny) {
436
412
  #pragma omp parallel for
437
413
  for (int64_t j = 0; j < nx; j++) {
438
- const int64_t * __restrict idsj = ids + j * ny;
439
- const float * xj = x + j * d;
440
- float * __restrict ipj = ip + j * ny;
414
+ const int64_t* __restrict idsj = ids + j * ny;
415
+ const float* xj = x + j * d;
416
+ float* __restrict ipj = ip + j * ny;
441
417
  for (size_t i = 0; i < ny; i++) {
442
418
  if (idsj[i] < 0)
443
419
  continue;
444
- ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
420
+ ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
445
421
  }
446
422
  }
447
423
  }
448
424
 
449
-
450
-
451
425
  /* compute the inner product between x and a subset y of ny vectors,
452
426
  whose indices are given by idy. */
453
- void fvec_L2sqr_by_idx (float * __restrict dis,
454
- const float * x,
455
- const float * y,
456
- const int64_t * __restrict ids, /* ids of y vecs */
457
- size_t d, size_t nx, size_t ny)
458
- {
427
+ void fvec_L2sqr_by_idx(
428
+ float* __restrict dis,
429
+ const float* x,
430
+ const float* y,
431
+ const int64_t* __restrict ids, /* ids of y vecs */
432
+ size_t d,
433
+ size_t nx,
434
+ size_t ny) {
459
435
  #pragma omp parallel for
460
436
  for (int64_t j = 0; j < nx; j++) {
461
- const int64_t * __restrict idsj = ids + j * ny;
462
- const float * xj = x + j * d;
463
- float * __restrict disj = dis + j * ny;
437
+ const int64_t* __restrict idsj = ids + j * ny;
438
+ const float* xj = x + j * d;
439
+ float* __restrict disj = dis + j * ny;
464
440
  for (size_t i = 0; i < ny; i++) {
465
441
  if (idsj[i] < 0)
466
442
  continue;
467
- disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
443
+ disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
468
444
  }
469
445
  }
470
446
  }
471
447
 
472
- void pairwise_indexed_L2sqr (
473
- size_t d, size_t n,
474
- const float * x, const int64_t *ix,
475
- const float * y, const int64_t *iy,
476
- float *dis)
477
- {
448
+ void pairwise_indexed_L2sqr(
449
+ size_t d,
450
+ size_t n,
451
+ const float* x,
452
+ const int64_t* ix,
453
+ const float* y,
454
+ const int64_t* iy,
455
+ float* dis) {
478
456
  #pragma omp parallel for
479
457
  for (int64_t j = 0; j < n; j++) {
480
458
  if (ix[j] >= 0 && iy[j] >= 0) {
481
- dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
459
+ dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
482
460
  }
483
461
  }
484
462
  }
485
463
 
486
- void pairwise_indexed_inner_product (
487
- size_t d, size_t n,
488
- const float * x, const int64_t *ix,
489
- const float * y, const int64_t *iy,
490
- float *dis)
491
- {
464
+ void pairwise_indexed_inner_product(
465
+ size_t d,
466
+ size_t n,
467
+ const float* x,
468
+ const int64_t* ix,
469
+ const float* y,
470
+ const int64_t* iy,
471
+ float* dis) {
492
472
  #pragma omp parallel for
493
473
  for (int64_t j = 0; j < n; j++) {
494
474
  if (ix[j] >= 0 && iy[j] >= 0) {
495
- dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
475
+ dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
496
476
  }
497
477
  }
498
478
  }
499
479
 
500
-
501
480
  /* Find the nearest neighbors for nx queries in a set of ny vectors
502
481
  indexed by ids. May be useful for re-ranking a pre-selected vector list */
503
- void knn_inner_products_by_idx (const float * x,
504
- const float * y,
505
- const int64_t * ids,
506
- size_t d, size_t nx, size_t ny,
507
- float_minheap_array_t * res)
508
- {
482
+ void knn_inner_products_by_idx(
483
+ const float* x,
484
+ const float* y,
485
+ const int64_t* ids,
486
+ size_t d,
487
+ size_t nx,
488
+ size_t ny,
489
+ float_minheap_array_t* res) {
509
490
  size_t k = res->k;
510
491
 
511
492
  #pragma omp parallel for
512
493
  for (int64_t i = 0; i < nx; i++) {
513
- const float * x_ = x + i * d;
514
- const int64_t * idsi = ids + i * ny;
494
+ const float* x_ = x + i * d;
495
+ const int64_t* idsi = ids + i * ny;
515
496
  size_t j;
516
- float * __restrict simi = res->get_val(i);
517
- int64_t * __restrict idxi = res->get_ids (i);
518
- minheap_heapify (k, simi, idxi);
497
+ float* __restrict simi = res->get_val(i);
498
+ int64_t* __restrict idxi = res->get_ids(i);
499
+ minheap_heapify(k, simi, idxi);
519
500
 
520
501
  for (j = 0; j < ny; j++) {
521
- if (idsi[j] < 0) break;
522
- float ip = fvec_inner_product (x_, y + d * idsi[j], d);
502
+ if (idsi[j] < 0)
503
+ break;
504
+ float ip = fvec_inner_product(x_, y + d * idsi[j], d);
523
505
 
524
506
  if (ip > simi[0]) {
525
- minheap_replace_top (k, simi, idxi, ip, idsi[j]);
507
+ minheap_replace_top(k, simi, idxi, ip, idsi[j]);
526
508
  }
527
509
  }
528
- minheap_reorder (k, simi, idxi);
510
+ minheap_reorder(k, simi, idxi);
529
511
  }
530
-
531
512
  }
532
513
 
533
- void knn_L2sqr_by_idx (const float * x,
534
- const float * y,
535
- const int64_t * __restrict ids,
536
- size_t d, size_t nx, size_t ny,
537
- float_maxheap_array_t * res)
538
- {
514
+ void knn_L2sqr_by_idx(
515
+ const float* x,
516
+ const float* y,
517
+ const int64_t* __restrict ids,
518
+ size_t d,
519
+ size_t nx,
520
+ size_t ny,
521
+ float_maxheap_array_t* res) {
539
522
  size_t k = res->k;
540
523
 
541
524
  #pragma omp parallel for
542
525
  for (int64_t i = 0; i < nx; i++) {
543
- const float * x_ = x + i * d;
544
- const int64_t * __restrict idsi = ids + i * ny;
545
- float * __restrict simi = res->get_val(i);
546
- int64_t * __restrict idxi = res->get_ids (i);
547
- maxheap_heapify (res->k, simi, idxi);
526
+ const float* x_ = x + i * d;
527
+ const int64_t* __restrict idsi = ids + i * ny;
528
+ float* __restrict simi = res->get_val(i);
529
+ int64_t* __restrict idxi = res->get_ids(i);
530
+ maxheap_heapify(res->k, simi, idxi);
548
531
  for (size_t j = 0; j < ny; j++) {
549
- float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
532
+ float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
550
533
 
551
534
  if (disij < simi[0]) {
552
- maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
535
+ maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
553
536
  }
554
537
  }
555
- maxheap_reorder (res->k, simi, idxi);
538
+ maxheap_reorder(res->k, simi, idxi);
556
539
  }
557
-
558
540
  }
559
541
 
560
-
561
-
562
-
563
-
564
- void pairwise_L2sqr (int64_t d,
565
- int64_t nq, const float *xq,
566
- int64_t nb, const float *xb,
567
- float *dis,
568
- int64_t ldq, int64_t ldb, int64_t ldd)
569
- {
570
- if (nq == 0 || nb == 0) return;
571
- if (ldq == -1) ldq = d;
572
- if (ldb == -1) ldb = d;
573
- if (ldd == -1) ldd = nb;
542
+ void pairwise_L2sqr(
543
+ int64_t d,
544
+ int64_t nq,
545
+ const float* xq,
546
+ int64_t nb,
547
+ const float* xb,
548
+ float* dis,
549
+ int64_t ldq,
550
+ int64_t ldb,
551
+ int64_t ldd) {
552
+ if (nq == 0 || nb == 0)
553
+ return;
554
+ if (ldq == -1)
555
+ ldq = d;
556
+ if (ldb == -1)
557
+ ldb = d;
558
+ if (ldd == -1)
559
+ ldd = nb;
574
560
 
575
561
  // store in beginning of distance matrix to avoid malloc
576
- float *b_norms = dis;
562
+ float* b_norms = dis;
577
563
 
578
564
  #pragma omp parallel for
579
565
  for (int64_t i = 0; i < nb; i++)
580
- b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
566
+ b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
581
567
 
582
568
  #pragma omp parallel for
583
569
  for (int64_t i = 1; i < nq; i++) {
584
- float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
570
+ float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
585
571
  for (int64_t j = 0; j < nb; j++)
586
- dis[i * ldd + j] = q_norm + b_norms [j];
572
+ dis[i * ldd + j] = q_norm + b_norms[j];
587
573
  }
588
574
 
589
575
  {
590
- float q_norm = fvec_norm_L2sqr (xq, d);
576
+ float q_norm = fvec_norm_L2sqr(xq, d);
591
577
  for (int64_t j = 0; j < nb; j++)
592
578
  dis[j] += q_norm;
593
579
  }
@@ -596,22 +582,28 @@ void pairwise_L2sqr (int64_t d,
596
582
  FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
597
583
  float one = 1.0, minus_2 = -2.0;
598
584
 
599
- sgemm_ ("Transposed", "Not transposed",
600
- &nbi, &nqi, &di,
601
- &minus_2,
602
- xb, &ldbi,
603
- xq, &ldqi,
604
- &one, dis, &lddi);
585
+ sgemm_("Transposed",
586
+ "Not transposed",
587
+ &nbi,
588
+ &nqi,
589
+ &di,
590
+ &minus_2,
591
+ xb,
592
+ &ldbi,
593
+ xq,
594
+ &ldqi,
595
+ &one,
596
+ dis,
597
+ &lddi);
605
598
  }
606
-
607
599
  }
608
600
 
609
- void inner_product_to_L2sqr(float* __restrict dis,
610
- const float* nr1,
611
- const float* nr2,
612
- size_t n1, size_t n2)
613
- {
614
-
601
+ void inner_product_to_L2sqr(
602
+ float* __restrict dis,
603
+ const float* nr1,
604
+ const float* nr2,
605
+ size_t n1,
606
+ size_t n2) {
615
607
  #pragma omp parallel for
616
608
  for (int64_t j = 0; j < n1; j++) {
617
609
  float* disj = dis + j * n2;
@@ -620,5 +612,4 @@ void inner_product_to_L2sqr(float* __restrict dis,
620
612
  }
621
613
  }
622
614
 
623
-
624
615
  } // namespace faiss