faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -5,55 +5,51 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
- #include <faiss/gpu/utils/Timer.h>
10
8
  #include <faiss/gpu/utils/DeviceUtils.h>
9
+ #include <faiss/gpu/utils/Timer.h>
11
10
  #include <faiss/impl/FaissAssert.h>
12
11
  #include <chrono>
13
12
 
14
- namespace faiss { namespace gpu {
13
+ namespace faiss {
14
+ namespace gpu {
15
15
 
16
16
  KernelTimer::KernelTimer(cudaStream_t stream)
17
- : startEvent_(0),
18
- stopEvent_(0),
19
- stream_(stream),
20
- valid_(true) {
21
- CUDA_VERIFY(cudaEventCreate(&startEvent_));
22
- CUDA_VERIFY(cudaEventCreate(&stopEvent_));
23
-
24
- CUDA_VERIFY(cudaEventRecord(startEvent_, stream_));
17
+ : startEvent_(0), stopEvent_(0), stream_(stream), valid_(true) {
18
+ CUDA_VERIFY(cudaEventCreate(&startEvent_));
19
+ CUDA_VERIFY(cudaEventCreate(&stopEvent_));
20
+
21
+ CUDA_VERIFY(cudaEventRecord(startEvent_, stream_));
25
22
  }
26
23
 
27
24
  KernelTimer::~KernelTimer() {
28
- CUDA_VERIFY(cudaEventDestroy(startEvent_));
29
- CUDA_VERIFY(cudaEventDestroy(stopEvent_));
25
+ CUDA_VERIFY(cudaEventDestroy(startEvent_));
26
+ CUDA_VERIFY(cudaEventDestroy(stopEvent_));
30
27
  }
31
28
 
32
- float
33
- KernelTimer::elapsedMilliseconds() {
34
- FAISS_ASSERT(valid_);
29
+ float KernelTimer::elapsedMilliseconds() {
30
+ FAISS_ASSERT(valid_);
35
31
 
36
- CUDA_VERIFY(cudaEventRecord(stopEvent_, stream_));
37
- CUDA_VERIFY(cudaEventSynchronize(stopEvent_));
32
+ CUDA_VERIFY(cudaEventRecord(stopEvent_, stream_));
33
+ CUDA_VERIFY(cudaEventSynchronize(stopEvent_));
38
34
 
39
- auto time = 0.0f;
40
- CUDA_VERIFY(cudaEventElapsedTime(&time, startEvent_, stopEvent_));
41
- valid_ = false;
35
+ auto time = 0.0f;
36
+ CUDA_VERIFY(cudaEventElapsedTime(&time, startEvent_, stopEvent_));
37
+ valid_ = false;
42
38
 
43
- return time;
39
+ return time;
44
40
  }
45
41
 
46
42
  CpuTimer::CpuTimer() {
47
- start_ = std::chrono::steady_clock::now();
43
+ start_ = std::chrono::steady_clock::now();
48
44
  }
49
45
 
50
- float
51
- CpuTimer::elapsedMilliseconds() {
52
- auto end = std::chrono::steady_clock::now();
46
+ float CpuTimer::elapsedMilliseconds() {
47
+ auto end = std::chrono::steady_clock::now();
53
48
 
54
- std::chrono::duration<float, std::milli> duration = end - start_;
49
+ std::chrono::duration<float, std::milli> duration = end - start_;
55
50
 
56
- return duration.count();
51
+ return duration.count();
57
52
  }
58
53
 
59
- } } // namespace
54
+ } // namespace gpu
55
+ } // namespace faiss
@@ -5,48 +5,49 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <cuda_runtime.h>
12
11
  #include <chrono>
13
12
 
14
- namespace faiss { namespace gpu {
13
+ namespace faiss {
14
+ namespace gpu {
15
15
 
16
16
  /// Utility class for timing execution of a kernel
17
17
  class KernelTimer {
18
- public:
19
- /// Constructor starts the timer and adds an event into the current
20
- /// device stream
21
- KernelTimer(cudaStream_t stream = 0);
22
-
23
- /// Destructor releases event resources
24
- ~KernelTimer();
25
-
26
- /// Adds a stop event then synchronizes on the stop event to get the
27
- /// actual GPU-side kernel timings for any kernels launched in the
28
- /// current stream. Returns the number of milliseconds elapsed.
29
- /// Can only be called once.
30
- float elapsedMilliseconds();
31
-
32
- private:
33
- cudaEvent_t startEvent_;
34
- cudaEvent_t stopEvent_;
35
- cudaStream_t stream_;
36
- bool valid_;
18
+ public:
19
+ /// Constructor starts the timer and adds an event into the current
20
+ /// device stream
21
+ KernelTimer(cudaStream_t stream = 0);
22
+
23
+ /// Destructor releases event resources
24
+ ~KernelTimer();
25
+
26
+ /// Adds a stop event then synchronizes on the stop event to get the
27
+ /// actual GPU-side kernel timings for any kernels launched in the
28
+ /// current stream. Returns the number of milliseconds elapsed.
29
+ /// Can only be called once.
30
+ float elapsedMilliseconds();
31
+
32
+ private:
33
+ cudaEvent_t startEvent_;
34
+ cudaEvent_t stopEvent_;
35
+ cudaStream_t stream_;
36
+ bool valid_;
37
37
  };
38
38
 
39
39
  /// CPU wallclock elapsed timer
40
40
  class CpuTimer {
41
- public:
42
- /// Creates and starts a new timer
43
- CpuTimer();
41
+ public:
42
+ /// Creates and starts a new timer
43
+ CpuTimer();
44
44
 
45
- /// Returns elapsed time in milliseconds
46
- float elapsedMilliseconds();
45
+ /// Returns elapsed time in milliseconds
46
+ float elapsedMilliseconds();
47
47
 
48
- private:
49
- std::chrono::time_point<std::chrono::steady_clock> start_;
48
+ private:
49
+ std::chrono::time_point<std::chrono::steady_clock> start_;
50
50
  };
51
51
 
52
- } } // namespace
52
+ } // namespace gpu
53
+ } // namespace faiss
@@ -0,0 +1,503 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/AdditiveQuantizer.h>
11
+
12
+ #include <cstddef>
13
+ #include <cstdio>
14
+ #include <cstring>
15
+ #include <memory>
16
+ #include <random>
17
+
18
+ #include <algorithm>
19
+
20
+ #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/utils/Heap.h>
22
+ #include <faiss/utils/distances.h>
23
+ #include <faiss/utils/hamming.h>
24
+ #include <faiss/utils/utils.h>
25
+
26
+ extern "C" {
27
+
28
+ // general matrix multiplication
29
+ int sgemm_(
30
+ const char* transa,
31
+ const char* transb,
32
+ FINTEGER* m,
33
+ FINTEGER* n,
34
+ FINTEGER* k,
35
+ const float* alpha,
36
+ const float* a,
37
+ FINTEGER* lda,
38
+ const float* b,
39
+ FINTEGER* ldb,
40
+ float* beta,
41
+ float* c,
42
+ FINTEGER* ldc);
43
+ }
44
+
45
+ namespace faiss {
46
+
47
+ AdditiveQuantizer::AdditiveQuantizer(
48
+ size_t d,
49
+ const std::vector<size_t>& nbits,
50
+ Search_type_t search_type)
51
+ : d(d),
52
+ M(nbits.size()),
53
+ nbits(nbits),
54
+ verbose(false),
55
+ is_trained(false),
56
+ search_type(search_type) {
57
+ norm_max = norm_min = NAN;
58
+ code_size = 0;
59
+ tot_bits = 0;
60
+ total_codebook_size = 0;
61
+ only_8bit = false;
62
+ set_derived_values();
63
+ }
64
+
65
+ AdditiveQuantizer::AdditiveQuantizer()
66
+ : AdditiveQuantizer(0, std::vector<size_t>()) {}
67
+
68
+ void AdditiveQuantizer::set_derived_values() {
69
+ tot_bits = 0;
70
+ only_8bit = true;
71
+ codebook_offsets.resize(M + 1, 0);
72
+ for (int i = 0; i < M; i++) {
73
+ int nbit = nbits[i];
74
+ size_t k = 1 << nbit;
75
+ codebook_offsets[i + 1] = codebook_offsets[i] + k;
76
+ tot_bits += nbit;
77
+ if (nbit != 0) {
78
+ only_8bit = false;
79
+ }
80
+ }
81
+ total_codebook_size = codebook_offsets[M];
82
+ switch (search_type) {
83
+ case ST_decompress:
84
+ case ST_LUT_nonorm:
85
+ case ST_norm_from_LUT:
86
+ break; // nothing to add
87
+ case ST_norm_float:
88
+ tot_bits += 32;
89
+ break;
90
+ case ST_norm_qint8:
91
+ case ST_norm_cqint8:
92
+ tot_bits += 8;
93
+ break;
94
+ case ST_norm_qint4:
95
+ case ST_norm_cqint4:
96
+ tot_bits += 4;
97
+ break;
98
+ }
99
+
100
+ // convert bits to bytes
101
+ code_size = (tot_bits + 7) / 8;
102
+ }
103
+
104
+ namespace {
105
+
106
+ // TODO
107
+ // https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0
108
+
109
+ uint8_t encode_qint8(float x, float amin, float amax) {
110
+ float x1 = (x - amin) / (amax - amin) * 256;
111
+ int32_t xi = int32_t(floor(x1));
112
+
113
+ return xi < 0 ? 0 : xi > 255 ? 255 : xi;
114
+ }
115
+
116
+ uint8_t encode_qint4(float x, float amin, float amax) {
117
+ float x1 = (x - amin) / (amax - amin) * 16;
118
+ int32_t xi = int32_t(floor(x1));
119
+
120
+ return xi < 0 ? 0 : xi > 15 ? 15 : xi;
121
+ }
122
+
123
+ float decode_qint8(uint8_t i, float amin, float amax) {
124
+ return (i + 0.5) / 256 * (amax - amin) + amin;
125
+ }
126
+
127
+ float decode_qint4(uint8_t i, float amin, float amax) {
128
+ return (i + 0.5) / 16 * (amax - amin) + amin;
129
+ }
130
+
131
+ } // anonymous namespace
132
+
133
+ uint32_t AdditiveQuantizer::encode_qcint(float x) const {
134
+ idx_t id;
135
+ qnorm.assign(idx_t(1), &x, &id, idx_t(1));
136
+ return uint32_t(id);
137
+ }
138
+
139
+ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
140
+ return qnorm.get_xb()[c];
141
+ }
142
+
143
+ void AdditiveQuantizer::pack_codes(
144
+ size_t n,
145
+ const int32_t* codes,
146
+ uint8_t* packed_codes,
147
+ int64_t ld_codes,
148
+ const float* norms) const {
149
+ if (ld_codes == -1) {
150
+ ld_codes = M;
151
+ }
152
+ std::vector<float> norm_buf;
153
+ if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
154
+ search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
155
+ search_type == ST_norm_cqint4) {
156
+ if (!norms) {
157
+ norm_buf.resize(n);
158
+ std::vector<float> x_recons(n * d);
159
+ decode_unpacked(codes, x_recons.data(), n, ld_codes);
160
+ fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
161
+ norms = norm_buf.data();
162
+ }
163
+ }
164
+ #pragma omp parallel for if (n > 1000)
165
+ for (int64_t i = 0; i < n; i++) {
166
+ const int32_t* codes1 = codes + i * ld_codes;
167
+ BitstringWriter bsw(packed_codes + i * code_size, code_size);
168
+ for (int m = 0; m < M; m++) {
169
+ bsw.write(codes1[m], nbits[m]);
170
+ }
171
+ switch (search_type) {
172
+ case ST_decompress:
173
+ case ST_LUT_nonorm:
174
+ case ST_norm_from_LUT:
175
+ break;
176
+ case ST_norm_float:
177
+ bsw.write(*(uint32_t*)&norms[i], 32);
178
+ break;
179
+ case ST_norm_qint8: {
180
+ uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
181
+ bsw.write(b, 8);
182
+ break;
183
+ }
184
+ case ST_norm_qint4: {
185
+ uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
186
+ bsw.write(b, 4);
187
+ break;
188
+ }
189
+ case ST_norm_cqint8: {
190
+ uint32_t b = encode_qcint(norms[i]);
191
+ bsw.write(b, 8);
192
+ break;
193
+ }
194
+ case ST_norm_cqint4: {
195
+ uint32_t b = encode_qcint(norms[i]);
196
+ bsw.write(b, 4);
197
+ break;
198
+ }
199
+ }
200
+ }
201
+ }
202
+
203
+ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
204
+ FAISS_THROW_IF_NOT_MSG(
205
+ is_trained, "The additive quantizer is not trained yet.");
206
+
207
+ // standard additive quantizer decoding
208
+ #pragma omp parallel for if (n > 1000)
209
+ for (int64_t i = 0; i < n; i++) {
210
+ BitstringReader bsr(code + i * code_size, code_size);
211
+ float* xi = x + i * d;
212
+ for (int m = 0; m < M; m++) {
213
+ int idx = bsr.read(nbits[m]);
214
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
215
+ if (m == 0) {
216
+ memcpy(xi, c, sizeof(*x) * d);
217
+ } else {
218
+ fvec_add(d, xi, c, xi);
219
+ }
220
+ }
221
+ }
222
+ }
223
+
224
+ void AdditiveQuantizer::decode_unpacked(
225
+ const int32_t* code,
226
+ float* x,
227
+ size_t n,
228
+ int64_t ld_codes) const {
229
+ FAISS_THROW_IF_NOT_MSG(
230
+ is_trained, "The additive quantizer is not trained yet.");
231
+
232
+ if (ld_codes == -1) {
233
+ ld_codes = M;
234
+ }
235
+
236
+ // standard additive quantizer decoding
237
+ #pragma omp parallel for if (n > 1000)
238
+ for (int64_t i = 0; i < n; i++) {
239
+ const int32_t* codesi = code + i * ld_codes;
240
+ float* xi = x + i * d;
241
+ for (int m = 0; m < M; m++) {
242
+ int idx = codesi[m];
243
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
244
+ if (m == 0) {
245
+ memcpy(xi, c, sizeof(*x) * d);
246
+ } else {
247
+ fvec_add(d, xi, c, xi);
248
+ }
249
+ }
250
+ }
251
+ }
252
+
253
+ AdditiveQuantizer::~AdditiveQuantizer() {}
254
+
255
+ /****************************************************************************
256
+ * Support for fast distance computations in centroids
257
+ ****************************************************************************/
258
+
259
+ void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
260
+ size_t ntotal = (size_t)1 << tot_bits;
261
+ // TODO: make tree of partial sums
262
+ #pragma omp parallel
263
+ {
264
+ std::vector<float> tmp(d);
265
+ #pragma omp for
266
+ for (int64_t i = 0; i < ntotal; i++) {
267
+ decode_64bit(i, tmp.data());
268
+ norms[i] = fvec_norm_L2sqr(tmp.data(), d);
269
+ }
270
+ }
271
+ }
272
+
273
+ void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
274
+ for (int m = 0; m < M; m++) {
275
+ idx_t idx = bits & (((size_t)1 << nbits[m]) - 1);
276
+ bits >>= nbits[m];
277
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
278
+ if (m == 0) {
279
+ memcpy(xi, c, sizeof(*xi) * d);
280
+ } else {
281
+ fvec_add(d, xi, c, xi);
282
+ }
283
+ }
284
+ }
285
+
286
+ void AdditiveQuantizer::compute_LUT(size_t n, const float* xq, float* LUT)
287
+ const {
288
+ // in all cases, it is large matrix multiplication
289
+
290
+ FINTEGER ncenti = total_codebook_size;
291
+ FINTEGER di = d;
292
+ FINTEGER nqi = n;
293
+ float one = 1, zero = 0;
294
+
295
+ sgemm_("Transposed",
296
+ "Not transposed",
297
+ &ncenti,
298
+ &nqi,
299
+ &di,
300
+ &one,
301
+ codebooks.data(),
302
+ &di,
303
+ xq,
304
+ &di,
305
+ &zero,
306
+ LUT,
307
+ &ncenti);
308
+ }
309
+
310
+ namespace {
311
+
312
+ void compute_inner_prod_with_LUT(
313
+ const AdditiveQuantizer& aq,
314
+ const float* LUT,
315
+ float* ips) {
316
+ size_t prev_size = 1;
317
+ for (int m = 0; m < aq.M; m++) {
318
+ const float* LUTm = LUT + aq.codebook_offsets[m];
319
+ int nb = aq.nbits[m];
320
+ size_t nc = (size_t)1 << nb;
321
+
322
+ if (m == 0) {
323
+ memcpy(ips, LUT, sizeof(*ips) * nc);
324
+ } else {
325
+ for (int64_t i = nc - 1; i >= 0; i--) {
326
+ float v = LUTm[i];
327
+ fvec_add(prev_size, ips, v, ips + i * prev_size);
328
+ }
329
+ }
330
+ prev_size *= nc;
331
+ }
332
+ }
333
+
334
+ } // anonymous namespace
335
+
336
+ void AdditiveQuantizer::knn_centroids_inner_product(
337
+ idx_t n,
338
+ const float* xq,
339
+ idx_t k,
340
+ float* distances,
341
+ idx_t* labels) const {
342
+ std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
343
+ compute_LUT(n, xq, LUT.get());
344
+ size_t ntotal = (size_t)1 << tot_bits;
345
+
346
+ #pragma omp parallel if (n > 100)
347
+ {
348
+ std::vector<float> dis(ntotal);
349
+ #pragma omp for
350
+ for (idx_t i = 0; i < n; i++) {
351
+ const float* LUTi = LUT.get() + i * total_codebook_size;
352
+ compute_inner_prod_with_LUT(*this, LUTi, dis.data());
353
+ float* distances_i = distances + i * k;
354
+ idx_t* labels_i = labels + i * k;
355
+ minheap_heapify(k, distances_i, labels_i);
356
+ minheap_addn(k, distances_i, labels_i, dis.data(), nullptr, ntotal);
357
+ minheap_reorder(k, distances_i, labels_i);
358
+ }
359
+ }
360
+ }
361
+
362
+ void AdditiveQuantizer::knn_centroids_L2(
363
+ idx_t n,
364
+ const float* xq,
365
+ idx_t k,
366
+ float* distances,
367
+ idx_t* labels,
368
+ const float* norms) const {
369
+ std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
370
+ compute_LUT(n, xq, LUT.get());
371
+ std::unique_ptr<float[]> q_norms(new float[n]);
372
+ fvec_norms_L2sqr(q_norms.get(), xq, d, n);
373
+ size_t ntotal = (size_t)1 << tot_bits;
374
+
375
+ #pragma omp parallel if (n > 100)
376
+ {
377
+ std::vector<float> dis(ntotal);
378
+ #pragma omp for
379
+ for (idx_t i = 0; i < n; i++) {
380
+ const float* LUTi = LUT.get() + i * total_codebook_size;
381
+ float* distances_i = distances + i * k;
382
+ idx_t* labels_i = labels + i * k;
383
+
384
+ compute_inner_prod_with_LUT(*this, LUTi, dis.data());
385
+
386
+ // update distances using
387
+ // ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x,y>
388
+
389
+ maxheap_heapify(k, distances_i, labels_i);
390
+ for (idx_t j = 0; j < ntotal; j++) {
391
+ float disj = q_norms[i] + norms[j] - 2 * dis[j];
392
+ if (disj < distances_i[0]) {
393
+ heap_replace_top<CMax<float, int64_t>>(
394
+ k, distances_i, labels_i, disj, j);
395
+ }
396
+ }
397
+ maxheap_reorder(k, distances_i, labels_i);
398
+ }
399
+ }
400
+ }
401
+
402
+ /****************************************************************************
403
+ * Support for fast distance computations in codes
404
+ ****************************************************************************/
405
+
406
+ namespace {
407
+
408
+ float accumulate_IPs(
409
+ const AdditiveQuantizer& aq,
410
+ BitstringReader& bs,
411
+ const uint8_t* codes,
412
+ const float* LUT) {
413
+ float accu = 0;
414
+ for (int m = 0; m < aq.M; m++) {
415
+ size_t nbit = aq.nbits[m];
416
+ int idx = bs.read(nbit);
417
+ accu += LUT[idx];
418
+ LUT += (uint64_t)1 << nbit;
419
+ }
420
+ return accu;
421
+ }
422
+
423
+ } // anonymous namespace
424
+
425
+ template <>
426
+ float AdditiveQuantizer::
427
+ compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
428
+ const uint8_t* codes,
429
+ const float* LUT) const {
430
+ BitstringReader bs(codes, code_size);
431
+ return accumulate_IPs(*this, bs, codes, LUT);
432
+ }
433
+
434
+ template <>
435
+ float AdditiveQuantizer::
436
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>(
437
+ const uint8_t* codes,
438
+ const float* LUT) const {
439
+ BitstringReader bs(codes, code_size);
440
+ return -accumulate_IPs(*this, bs, codes, LUT);
441
+ }
442
+
443
+ template <>
444
+ float AdditiveQuantizer::
445
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>(
446
+ const uint8_t* codes,
447
+ const float* LUT) const {
448
+ BitstringReader bs(codes, code_size);
449
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
450
+ uint32_t norm_i = bs.read(32);
451
+ float norm2 = *(float*)&norm_i;
452
+ return norm2 - 2 * accu;
453
+ }
454
+
455
+ template <>
456
+ float AdditiveQuantizer::
457
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
458
+ const uint8_t* codes,
459
+ const float* LUT) const {
460
+ BitstringReader bs(codes, code_size);
461
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
462
+ uint32_t norm_i = bs.read(8);
463
+ float norm2 = decode_qcint(norm_i);
464
+ return norm2 - 2 * accu;
465
+ }
466
+
467
+ template <>
468
+ float AdditiveQuantizer::
469
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>(
470
+ const uint8_t* codes,
471
+ const float* LUT) const {
472
+ BitstringReader bs(codes, code_size);
473
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
474
+ uint32_t norm_i = bs.read(4);
475
+ float norm2 = decode_qcint(norm_i);
476
+ return norm2 - 2 * accu;
477
+ }
478
+
479
+ template <>
480
+ float AdditiveQuantizer::
481
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>(
482
+ const uint8_t* codes,
483
+ const float* LUT) const {
484
+ BitstringReader bs(codes, code_size);
485
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
486
+ uint32_t norm_i = bs.read(8);
487
+ float norm2 = decode_qint8(norm_i, norm_min, norm_max);
488
+ return norm2 - 2 * accu;
489
+ }
490
+
491
+ template <>
492
+ float AdditiveQuantizer::
493
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>(
494
+ const uint8_t* codes,
495
+ const float* LUT) const {
496
+ BitstringReader bs(codes, code_size);
497
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
498
+ uint32_t norm_i = bs.read(4);
499
+ float norm2 = decode_qint4(norm_i, norm_min, norm_max);
500
+ return norm2 - 2 * accu;
501
+ }
502
+
503
+ } // namespace faiss