faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,126 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #include <faiss/utils/WorkerThread.h>
10
+ #include <faiss/impl/FaissAssert.h>
11
+ #include <exception>
12
+
13
+ namespace faiss {
14
+
15
+ namespace {
16
+
17
// Runs `fn` and publishes the outcome through `promise`: the value `true`
// when `fn` returns normally, otherwise the exception that escaped from it.
void runCallback(std::function<void()>& fn,
                 std::promise<bool>& promise) {
  try {
    fn();
    promise.set_value(true);
  } catch (...) {
    // Hand whatever was thrown over to the holder of the future.
    std::exception_ptr error = std::current_exception();
    promise.set_exception(error);
  }
}
27
+
28
+ } // namespace
29
+
30
+ WorkerThread::WorkerThread() :
31
+ wantStop_(false) {
32
+ startThread();
33
+
34
+ // Make sure that the thread has started before continuing
35
+ add([](){}).get();
36
+ }
37
+
38
+ WorkerThread::~WorkerThread() {
39
+ stop();
40
+ waitForThreadExit();
41
+ }
42
+
43
+ void
44
+ WorkerThread::startThread() {
45
+ thread_ = std::thread([this](){ threadMain(); });
46
+ }
47
+
48
+ void
49
+ WorkerThread::stop() {
50
+ std::lock_guard<std::mutex> guard(mutex_);
51
+
52
+ wantStop_ = true;
53
+ monitor_.notify_one();
54
+ }
55
+
56
+ std::future<bool>
57
+ WorkerThread::add(std::function<void()> f) {
58
+ std::lock_guard<std::mutex> guard(mutex_);
59
+
60
+ if (wantStop_) {
61
+ // The timer thread has been stopped, or we want to stop; we can't
62
+ // schedule anything else
63
+ std::promise<bool> p;
64
+ auto fut = p.get_future();
65
+
66
+ // did not execute
67
+ p.set_value(false);
68
+ return fut;
69
+ }
70
+
71
+ auto pr = std::promise<bool>();
72
+ auto fut = pr.get_future();
73
+
74
+ queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
75
+
76
+ // Wake up our thread
77
+ monitor_.notify_one();
78
+ return fut;
79
+ }
80
+
81
+ void
82
+ WorkerThread::threadMain() {
83
+ threadLoop();
84
+
85
+ // Call all pending tasks
86
+ FAISS_ASSERT(wantStop_);
87
+
88
+ // flush all pending operations
89
+ for (auto& f : queue_) {
90
+ runCallback(f.first, f.second);
91
+ }
92
+ }
93
+
94
+ void
95
+ WorkerThread::threadLoop() {
96
+ while (true) {
97
+ std::pair<std::function<void()>, std::promise<bool>> data;
98
+
99
+ {
100
+ std::unique_lock<std::mutex> lock(mutex_);
101
+
102
+ while (!wantStop_ && queue_.empty()) {
103
+ monitor_.wait(lock);
104
+ }
105
+
106
+ if (wantStop_) {
107
+ return;
108
+ }
109
+
110
+ data = std::move(queue_.front());
111
+ queue_.pop_front();
112
+ }
113
+
114
+ runCallback(data.first, data.second);
115
+ }
116
+ }
117
+
118
+ void
119
+ WorkerThread::waitForThreadExit() {
120
+ try {
121
+ thread_.join();
122
+ } catch (...) {
123
+ }
124
+ }
125
+
126
+ } // namespace
@@ -0,0 +1,61 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #pragma once
10
+
11
+ #include <condition_variable>
12
+ #include <future>
13
+ #include <deque>
14
+ #include <thread>
15
+
16
+ namespace faiss {
17
+
18
/// Owns a single background thread that executes queued lambdas in FIFO
/// order.
class WorkerThread {
 public:
  /// Starts the worker thread.
  WorkerThread();

  /// Requests a stop and waits for the worker thread to exit; lambdas
  /// still pending at that point are flushed.
  ~WorkerThread();

  /// Asks the worker thread to stop; does not wait for it.
  void stop();

  /// Blocks the calling thread until the worker thread has stopped.
  void waitForThreadExit();

  /// Schedules a lambda on the worker thread and returns a future that can
  /// be used to wait for its completion.
  /// The future holds `true` if the lambda was run on the worker thread,
  /// and `false` if it was not run because the worker thread is exiting
  /// or has exited.
  std::future<bool> add(std::function<void()> f);

 private:
  /// A queued lambda together with the promise that reports its outcome
  using Task = std::pair<std::function<void()>, std::promise<bool>>;

  void startThread();
  void threadMain();
  void threadLoop();

  /// The thread on which all queued lambdas run
  std::thread thread_;

  /// Protects the queue and the exit status
  std::mutex mutex_;

  /// Signalled when the queue or the exit status changes
  std::condition_variable monitor_;

  /// True once the thread has been asked to exit
  bool wantStop_;

  /// Lambdas waiting to be called
  std::deque<Task> queue_;
};
60
+
61
+ } // namespace
@@ -0,0 +1,765 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/utils/distances.h>
11
+
12
+ #include <cstdio>
13
+ #include <cassert>
14
+ #include <cstring>
15
+ #include <cmath>
16
+
17
+ #include <omp.h>
18
+
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
21
+
22
+
23
+
24
+ #ifndef FINTEGER
25
+ #define FINTEGER long
26
+ #endif
27
+
28
+
29
+ extern "C" {
30
+
31
+ /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
32
+
33
+ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
34
+ n, FINTEGER *k, const float *alpha, const float *a,
35
+ FINTEGER *lda, const float *b, FINTEGER *
36
+ ldb, float *beta, float *c, FINTEGER *ldc);
37
+
38
+ /* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
39
+
40
+ int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
41
+ float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
42
+
43
+ int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
44
+ const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
45
+ float *beta, float *y, FINTEGER *incy);
46
+
47
+ }
48
+
49
+
50
+ namespace faiss {
51
+
52
+
53
+
54
+ /***************************************************************************
55
+ * Matrix/vector ops
56
+ ***************************************************************************/
57
+
58
+
59
+
60
+ /* Compute the inner product between a vector x and
61
+ a set of ny vectors y.
62
+ These functions are not intended to replace BLAS matrix-matrix, as they
63
+ would be significantly less efficient in this case. */
64
+ void fvec_inner_products_ny (float * ip,
65
+ const float * x,
66
+ const float * y,
67
+ size_t d, size_t ny)
68
+ {
69
+ // Not sure which one is fastest
70
+ #if 0
71
+ {
72
+ FINTEGER di = d;
73
+ FINTEGER nyi = ny;
74
+ float one = 1.0, zero = 0.0;
75
+ FINTEGER onei = 1;
76
+ sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
77
+ }
78
+ #endif
79
+ for (size_t i = 0; i < ny; i++) {
80
+ ip[i] = fvec_inner_product (x, y, d);
81
+ y += d;
82
+ }
83
+ }
84
+
85
+
86
+
87
+
88
+
89
+ /* Compute the L2 norm of a set of nx vectors */
90
+ void fvec_norms_L2 (float * __restrict nr,
91
+ const float * __restrict x,
92
+ size_t d, size_t nx)
93
+ {
94
+
95
+ #pragma omp parallel for
96
+ for (size_t i = 0; i < nx; i++) {
97
+ nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
98
+ }
99
+ }
100
+
101
+ void fvec_norms_L2sqr (float * __restrict nr,
102
+ const float * __restrict x,
103
+ size_t d, size_t nx)
104
+ {
105
+ #pragma omp parallel for
106
+ for (size_t i = 0; i < nx; i++)
107
+ nr[i] = fvec_norm_L2sqr (x + i * d, d);
108
+ }
109
+
110
+
111
+
112
+ void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
113
+ {
114
+ #pragma omp parallel for
115
+ for (size_t i = 0; i < nx; i++) {
116
+ float * __restrict xi = x + i * d;
117
+
118
+ float nr = fvec_norm_L2sqr (xi, d);
119
+
120
+ if (nr > 0) {
121
+ size_t j;
122
+ const float inv_nr = 1.0 / sqrtf (nr);
123
+ for (j = 0; j < d; j++)
124
+ xi[j] *= inv_nr;
125
+ }
126
+ }
127
+ }
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+ /***************************************************************************
141
+ * KNN functions
142
+ ***************************************************************************/
143
+
144
+
145
+
146
+ /* Find the nearest neighbors for nx queries in a set of ny vectors */
147
+ static void knn_inner_product_sse (const float * x,
148
+ const float * y,
149
+ size_t d, size_t nx, size_t ny,
150
+ float_minheap_array_t * res)
151
+ {
152
+ size_t k = res->k;
153
+ size_t check_period = InterruptCallback::get_period_hint (ny * d);
154
+
155
+ check_period *= omp_get_max_threads();
156
+
157
+ for (size_t i0 = 0; i0 < nx; i0 += check_period) {
158
+ size_t i1 = std::min(i0 + check_period, nx);
159
+
160
+ #pragma omp parallel for
161
+ for (size_t i = i0; i < i1; i++) {
162
+ const float * x_i = x + i * d;
163
+ const float * y_j = y;
164
+
165
+ float * __restrict simi = res->get_val(i);
166
+ int64_t * __restrict idxi = res->get_ids (i);
167
+
168
+ minheap_heapify (k, simi, idxi);
169
+
170
+ for (size_t j = 0; j < ny; j++) {
171
+ float ip = fvec_inner_product (x_i, y_j, d);
172
+
173
+ if (ip > simi[0]) {
174
+ minheap_pop (k, simi, idxi);
175
+ minheap_push (k, simi, idxi, ip, j);
176
+ }
177
+ y_j += d;
178
+ }
179
+ minheap_reorder (k, simi, idxi);
180
+ }
181
+ InterruptCallback::check ();
182
+ }
183
+
184
+ }
185
+
186
+ static void knn_L2sqr_sse (
187
+ const float * x,
188
+ const float * y,
189
+ size_t d, size_t nx, size_t ny,
190
+ float_maxheap_array_t * res)
191
+ {
192
+ size_t k = res->k;
193
+
194
+ size_t check_period = InterruptCallback::get_period_hint (ny * d);
195
+ check_period *= omp_get_max_threads();
196
+
197
+ for (size_t i0 = 0; i0 < nx; i0 += check_period) {
198
+ size_t i1 = std::min(i0 + check_period, nx);
199
+
200
+ #pragma omp parallel for
201
+ for (size_t i = i0; i < i1; i++) {
202
+ const float * x_i = x + i * d;
203
+ const float * y_j = y;
204
+ size_t j;
205
+ float * simi = res->get_val(i);
206
+ int64_t * idxi = res->get_ids (i);
207
+
208
+ maxheap_heapify (k, simi, idxi);
209
+ for (j = 0; j < ny; j++) {
210
+ float disij = fvec_L2sqr (x_i, y_j, d);
211
+
212
+ if (disij < simi[0]) {
213
+ maxheap_pop (k, simi, idxi);
214
+ maxheap_push (k, simi, idxi, disij, j);
215
+ }
216
+ y_j += d;
217
+ }
218
+ maxheap_reorder (k, simi, idxi);
219
+ }
220
+ InterruptCallback::check ();
221
+ }
222
+
223
+ }
224
+
225
+
226
+ /** Find the nearest neighbors for nx queries in a set of ny vectors */
227
+ static void knn_inner_product_blas (
228
+ const float * x,
229
+ const float * y,
230
+ size_t d, size_t nx, size_t ny,
231
+ float_minheap_array_t * res)
232
+ {
233
+ res->heapify ();
234
+
235
+ // BLAS does not like empty matrices
236
+ if (nx == 0 || ny == 0) return;
237
+
238
+ /* block sizes */
239
+ const size_t bs_x = 4096, bs_y = 1024;
240
+ // const size_t bs_x = 16, bs_y = 16;
241
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
242
+
243
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
244
+ size_t i1 = i0 + bs_x;
245
+ if(i1 > nx) i1 = nx;
246
+
247
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
248
+ size_t j1 = j0 + bs_y;
249
+ if (j1 > ny) j1 = ny;
250
+ /* compute the actual dot products */
251
+ {
252
+ float one = 1, zero = 0;
253
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
254
+ sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
255
+ y + j0 * d, &di,
256
+ x + i0 * d, &di, &zero,
257
+ ip_block.get(), &nyi);
258
+ }
259
+
260
+ /* collect maxima */
261
+ res->addn (j1 - j0, ip_block.get(), j0, i0, i1 - i0);
262
+ }
263
+ InterruptCallback::check ();
264
+ }
265
+ res->reorder ();
266
+ }
267
+
268
+ // distance correction is an operator that can be applied to transform
269
+ // the distances
270
+ template<class DistanceCorrection>
271
+ static void knn_L2sqr_blas (const float * x,
272
+ const float * y,
273
+ size_t d, size_t nx, size_t ny,
274
+ float_maxheap_array_t * res,
275
+ const DistanceCorrection &corr)
276
+ {
277
+ res->heapify ();
278
+
279
+ // BLAS does not like empty matrices
280
+ if (nx == 0 || ny == 0) return;
281
+
282
+ size_t k = res->k;
283
+
284
+ /* block sizes */
285
+ const size_t bs_x = 4096, bs_y = 1024;
286
+ // const size_t bs_x = 16, bs_y = 16;
287
+ float *ip_block = new float[bs_x * bs_y];
288
+ float *x_norms = new float[nx];
289
+ float *y_norms = new float[ny];
290
+ ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
291
+
292
+ fvec_norms_L2sqr (x_norms, x, d, nx);
293
+ fvec_norms_L2sqr (y_norms, y, d, ny);
294
+
295
+
296
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
297
+ size_t i1 = i0 + bs_x;
298
+ if(i1 > nx) i1 = nx;
299
+
300
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
301
+ size_t j1 = j0 + bs_y;
302
+ if (j1 > ny) j1 = ny;
303
+ /* compute the actual dot products */
304
+ {
305
+ float one = 1, zero = 0;
306
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
307
+ sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
308
+ y + j0 * d, &di,
309
+ x + i0 * d, &di, &zero,
310
+ ip_block, &nyi);
311
+ }
312
+
313
+ /* collect minima */
314
+ #pragma omp parallel for
315
+ for (size_t i = i0; i < i1; i++) {
316
+ float * __restrict simi = res->get_val(i);
317
+ int64_t * __restrict idxi = res->get_ids (i);
318
+ const float *ip_line = ip_block + (i - i0) * (j1 - j0);
319
+
320
+ for (size_t j = j0; j < j1; j++) {
321
+ float ip = *ip_line++;
322
+ float dis = x_norms[i] + y_norms[j] - 2 * ip;
323
+
324
+ // negative values can occur for identical vectors
325
+ // due to roundoff errors
326
+ if (dis < 0) dis = 0;
327
+
328
+ dis = corr (dis, i, j);
329
+
330
+ if (dis < simi[0]) {
331
+ maxheap_pop (k, simi, idxi);
332
+ maxheap_push (k, simi, idxi, dis, j);
333
+ }
334
+ }
335
+ }
336
+ }
337
+ InterruptCallback::check ();
338
+ }
339
+ res->reorder ();
340
+
341
+ }
342
+
343
+
344
+
345
+
346
+
347
+
348
+
349
+
350
+
351
+ /*******************************************************
352
+ * KNN driver functions
353
+ *******************************************************/
354
+
355
+ int distance_compute_blas_threshold = 20;
356
+
357
+ void knn_inner_product (const float * x,
358
+ const float * y,
359
+ size_t d, size_t nx, size_t ny,
360
+ float_minheap_array_t * res)
361
+ {
362
+ if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
363
+ knn_inner_product_sse (x, y, d, nx, ny, res);
364
+ } else {
365
+ knn_inner_product_blas (x, y, d, nx, ny, res);
366
+ }
367
+ }
368
+
369
+
370
+
371
/// Distance correction that returns every distance unchanged; the query
/// and database indices are ignored.
struct NopDistanceCorrection {
    float operator()(float dis, size_t, size_t) const {
        return dis;
    }
};
376
+
377
+ void knn_L2sqr (const float * x,
378
+ const float * y,
379
+ size_t d, size_t nx, size_t ny,
380
+ float_maxheap_array_t * res)
381
+ {
382
+ if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
383
+ knn_L2sqr_sse (x, y, d, nx, ny, res);
384
+ } else {
385
+ NopDistanceCorrection nop;
386
+ knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
387
+ }
388
+ }
389
+
390
/// Distance correction that subtracts a per-database-vector offset:
/// the result is dis - base_shift[bno]. The query index is ignored.
struct BaseShiftDistanceCorrection {
    /// offsets, indexed by database vector number
    const float *base_shift;

    float operator()(float dis, size_t, size_t bno) const {
        const float shift = base_shift[bno];
        return dis - shift;
    }
};
396
+
397
+ void knn_L2sqr_base_shift (
398
+ const float * x,
399
+ const float * y,
400
+ size_t d, size_t nx, size_t ny,
401
+ float_maxheap_array_t * res,
402
+ const float *base_shift)
403
+ {
404
+ BaseShiftDistanceCorrection corr = {base_shift};
405
+ knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
406
+ }
407
+
408
+
409
+
410
+ /***************************************************************************
411
+ * compute a subset of distances
412
+ ***************************************************************************/
413
+
414
+ /* compute the inner product between x and a subset y of ny vectors,
415
+ whose indices are given by idy. */
416
+ void fvec_inner_products_by_idx (float * __restrict ip,
417
+ const float * x,
418
+ const float * y,
419
+ const int64_t * __restrict ids, /* for y vecs */
420
+ size_t d, size_t nx, size_t ny)
421
+ {
422
+ #pragma omp parallel for
423
+ for (size_t j = 0; j < nx; j++) {
424
+ const int64_t * __restrict idsj = ids + j * ny;
425
+ const float * xj = x + j * d;
426
+ float * __restrict ipj = ip + j * ny;
427
+ for (size_t i = 0; i < ny; i++) {
428
+ if (idsj[i] < 0)
429
+ continue;
430
+ ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
431
+ }
432
+ }
433
+ }
434
+
435
+
436
+
437
+ /* compute the inner product between x and a subset y of ny vectors,
438
+ whose indices are given by idy. */
439
+ void fvec_L2sqr_by_idx (float * __restrict dis,
440
+ const float * x,
441
+ const float * y,
442
+ const int64_t * __restrict ids, /* ids of y vecs */
443
+ size_t d, size_t nx, size_t ny)
444
+ {
445
+ #pragma omp parallel for
446
+ for (size_t j = 0; j < nx; j++) {
447
+ const int64_t * __restrict idsj = ids + j * ny;
448
+ const float * xj = x + j * d;
449
+ float * __restrict disj = dis + j * ny;
450
+ for (size_t i = 0; i < ny; i++) {
451
+ if (idsj[i] < 0)
452
+ continue;
453
+ disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
454
+ }
455
+ }
456
+ }
457
+
458
+ void pairwise_indexed_L2sqr (
459
+ size_t d, size_t n,
460
+ const float * x, const int64_t *ix,
461
+ const float * y, const int64_t *iy,
462
+ float *dis)
463
+ {
464
+ #pragma omp parallel for
465
+ for (size_t j = 0; j < n; j++) {
466
+ if (ix[j] >= 0 && iy[j] >= 0) {
467
+ dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
468
+ }
469
+ }
470
+ }
471
+
472
+ void pairwise_indexed_inner_product (
473
+ size_t d, size_t n,
474
+ const float * x, const int64_t *ix,
475
+ const float * y, const int64_t *iy,
476
+ float *dis)
477
+ {
478
+ #pragma omp parallel for
479
+ for (size_t j = 0; j < n; j++) {
480
+ if (ix[j] >= 0 && iy[j] >= 0) {
481
+ dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
482
+ }
483
+ }
484
+ }
485
+
486
+
487
+ /* Find the nearest neighbors for nx queries in a set of ny vectors
488
+ indexed by ids. May be useful for re-ranking a pre-selected vector list */
489
+ void knn_inner_products_by_idx (const float * x,
490
+ const float * y,
491
+ const int64_t * ids,
492
+ size_t d, size_t nx, size_t ny,
493
+ float_minheap_array_t * res)
494
+ {
495
+ size_t k = res->k;
496
+
497
+ #pragma omp parallel for
498
+ for (size_t i = 0; i < nx; i++) {
499
+ const float * x_ = x + i * d;
500
+ const int64_t * idsi = ids + i * ny;
501
+ size_t j;
502
+ float * __restrict simi = res->get_val(i);
503
+ int64_t * __restrict idxi = res->get_ids (i);
504
+ minheap_heapify (k, simi, idxi);
505
+
506
+ for (j = 0; j < ny; j++) {
507
+ if (idsi[j] < 0) break;
508
+ float ip = fvec_inner_product (x_, y + d * idsi[j], d);
509
+
510
+ if (ip > simi[0]) {
511
+ minheap_pop (k, simi, idxi);
512
+ minheap_push (k, simi, idxi, ip, idsi[j]);
513
+ }
514
+ }
515
+ minheap_reorder (k, simi, idxi);
516
+ }
517
+
518
+ }
519
+
520
+ void knn_L2sqr_by_idx (const float * x,
521
+ const float * y,
522
+ const int64_t * __restrict ids,
523
+ size_t d, size_t nx, size_t ny,
524
+ float_maxheap_array_t * res)
525
+ {
526
+ size_t k = res->k;
527
+
528
+ #pragma omp parallel for
529
+ for (size_t i = 0; i < nx; i++) {
530
+ const float * x_ = x + i * d;
531
+ const int64_t * __restrict idsi = ids + i * ny;
532
+ float * __restrict simi = res->get_val(i);
533
+ int64_t * __restrict idxi = res->get_ids (i);
534
+ maxheap_heapify (res->k, simi, idxi);
535
+ for (size_t j = 0; j < ny; j++) {
536
+ float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
537
+
538
+ if (disij < simi[0]) {
539
+ maxheap_pop (k, simi, idxi);
540
+ maxheap_push (k, simi, idxi, disij, idsi[j]);
541
+ }
542
+ }
543
+ maxheap_reorder (res->k, simi, idxi);
544
+ }
545
+
546
+ }
547
+
548
+
549
+
550
+
551
+
552
+ /***************************************************************************
553
+ * Range search
554
+ ***************************************************************************/
555
+
556
+ /** Find the nearest neighbors for nx queries in a set of ny vectors
557
+ * compute_l2 = compute pairwise squared L2 distance rather than inner prod
558
+ */
559
+ template <bool compute_l2>
560
+ static void range_search_blas (
561
+ const float * x,
562
+ const float * y,
563
+ size_t d, size_t nx, size_t ny,
564
+ float radius,
565
+ RangeSearchResult *result)
566
+ {
567
+
568
+ // BLAS does not like empty matrices
569
+ if (nx == 0 || ny == 0) return;
570
+
571
+ /* block sizes */
572
+ const size_t bs_x = 4096, bs_y = 1024;
573
+ // const size_t bs_x = 16, bs_y = 16;
574
+ float *ip_block = new float[bs_x * bs_y];
575
+ ScopeDeleter<float> del0(ip_block);
576
+
577
+ float *x_norms = nullptr, *y_norms = nullptr;
578
+ ScopeDeleter<float> del1, del2;
579
+ if (compute_l2) {
580
+ x_norms = new float[nx];
581
+ del1.set (x_norms);
582
+ fvec_norms_L2sqr (x_norms, x, d, nx);
583
+
584
+ y_norms = new float[ny];
585
+ del2.set (y_norms);
586
+ fvec_norms_L2sqr (y_norms, y, d, ny);
587
+ }
588
+
589
+ std::vector <RangeSearchPartialResult *> partial_results;
590
+
591
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
592
+ size_t j1 = j0 + bs_y;
593
+ if (j1 > ny) j1 = ny;
594
+ RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
595
+ partial_results.push_back (pres);
596
+
597
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
598
+ size_t i1 = i0 + bs_x;
599
+ if(i1 > nx) i1 = nx;
600
+
601
+ /* compute the actual dot products */
602
+ {
603
+ float one = 1, zero = 0;
604
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
605
+ sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
606
+ y + j0 * d, &di,
607
+ x + i0 * d, &di, &zero,
608
+ ip_block, &nyi);
609
+ }
610
+
611
+
612
+ for (size_t i = i0; i < i1; i++) {
613
+ const float *ip_line = ip_block + (i - i0) * (j1 - j0);
614
+
615
+ RangeQueryResult & qres = pres->new_result (i);
616
+
617
+ for (size_t j = j0; j < j1; j++) {
618
+ float ip = *ip_line++;
619
+ if (compute_l2) {
620
+ float dis = x_norms[i] + y_norms[j] - 2 * ip;
621
+ if (dis < radius) {
622
+ qres.add (dis, j);
623
+ }
624
+ } else {
625
+ if (ip > radius) {
626
+ qres.add (ip, j);
627
+ }
628
+ }
629
+ }
630
+ }
631
+ }
632
+ InterruptCallback::check ();
633
+ }
634
+
635
+ RangeSearchPartialResult::merge (partial_results);
636
+ }
637
+
638
+
639
+ template <bool compute_l2>
640
+ static void range_search_sse (const float * x,
641
+ const float * y,
642
+ size_t d, size_t nx, size_t ny,
643
+ float radius,
644
+ RangeSearchResult *res)
645
+ {
646
+ FAISS_THROW_IF_NOT (d % 4 == 0);
647
+
648
+ #pragma omp parallel
649
+ {
650
+ RangeSearchPartialResult pres (res);
651
+
652
+ #pragma omp for
653
+ for (size_t i = 0; i < nx; i++) {
654
+ const float * x_ = x + i * d;
655
+ const float * y_ = y;
656
+ size_t j;
657
+
658
+ RangeQueryResult & qres = pres.new_result (i);
659
+
660
+ for (j = 0; j < ny; j++) {
661
+ if (compute_l2) {
662
+ float disij = fvec_L2sqr (x_, y_, d);
663
+ if (disij < radius) {
664
+ qres.add (disij, j);
665
+ }
666
+ } else {
667
+ float ip = fvec_inner_product (x_, y_, d);
668
+ if (ip > radius) {
669
+ qres.add (ip, j);
670
+ }
671
+ }
672
+ y_ += d;
673
+ }
674
+
675
+ }
676
+ pres.finalize ();
677
+ }
678
+
679
+ // check just at the end because the use case is typically just
680
+ // when the nb of queries is low.
681
+ InterruptCallback::check();
682
+ }
683
+
684
+
685
+
686
+
687
+
688
+ void range_search_L2sqr (
689
+ const float * x,
690
+ const float * y,
691
+ size_t d, size_t nx, size_t ny,
692
+ float radius,
693
+ RangeSearchResult *res)
694
+ {
695
+
696
+ if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
697
+ range_search_sse<true> (x, y, d, nx, ny, radius, res);
698
+ } else {
699
+ range_search_blas<true> (x, y, d, nx, ny, radius, res);
700
+ }
701
+ }
702
+
703
+ void range_search_inner_product (
704
+ const float * x,
705
+ const float * y,
706
+ size_t d, size_t nx, size_t ny,
707
+ float radius,
708
+ RangeSearchResult *res)
709
+ {
710
+
711
+ if (d % 4 == 0 && nx < distance_compute_blas_threshold) {
712
+ range_search_sse<false> (x, y, d, nx, ny, radius, res);
713
+ } else {
714
+ range_search_blas<false> (x, y, d, nx, ny, radius, res);
715
+ }
716
+ }
717
+
718
+
719
+ void pairwise_L2sqr (int64_t d,
720
+ int64_t nq, const float *xq,
721
+ int64_t nb, const float *xb,
722
+ float *dis,
723
+ int64_t ldq, int64_t ldb, int64_t ldd)
724
+ {
725
+ if (nq == 0 || nb == 0) return;
726
+ if (ldq == -1) ldq = d;
727
+ if (ldb == -1) ldb = d;
728
+ if (ldd == -1) ldd = nb;
729
+
730
+ // store in beginning of distance matrix to avoid malloc
731
+ float *b_norms = dis;
732
+
733
+ #pragma omp parallel for
734
+ for (int64_t i = 0; i < nb; i++)
735
+ b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
736
+
737
+ #pragma omp parallel for
738
+ for (int64_t i = 1; i < nq; i++) {
739
+ float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
740
+ for (int64_t j = 0; j < nb; j++)
741
+ dis[i * ldd + j] = q_norm + b_norms [j];
742
+ }
743
+
744
+ {
745
+ float q_norm = fvec_norm_L2sqr (xq, d);
746
+ for (int64_t j = 0; j < nb; j++)
747
+ dis[j] += q_norm;
748
+ }
749
+
750
+ {
751
+ FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
752
+ float one = 1.0, minus_2 = -2.0;
753
+
754
+ sgemm_ ("Transposed", "Not transposed",
755
+ &nbi, &nqi, &di,
756
+ &minus_2,
757
+ xb, &ldbi,
758
+ xq, &ldqi,
759
+ &one, dis, &lddi);
760
+ }
761
+
762
+ }
763
+
764
+
765
+ } // namespace faiss