faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -5,55 +5,51 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
- #include <faiss/gpu/utils/Timer.h>
10
8
  #include <faiss/gpu/utils/DeviceUtils.h>
9
+ #include <faiss/gpu/utils/Timer.h>
11
10
  #include <faiss/impl/FaissAssert.h>
12
11
  #include <chrono>
13
12
 
14
- namespace faiss { namespace gpu {
13
+ namespace faiss {
14
+ namespace gpu {
15
15
 
16
16
  KernelTimer::KernelTimer(cudaStream_t stream)
17
- : startEvent_(0),
18
- stopEvent_(0),
19
- stream_(stream),
20
- valid_(true) {
21
- CUDA_VERIFY(cudaEventCreate(&startEvent_));
22
- CUDA_VERIFY(cudaEventCreate(&stopEvent_));
23
-
24
- CUDA_VERIFY(cudaEventRecord(startEvent_, stream_));
17
+ : startEvent_(0), stopEvent_(0), stream_(stream), valid_(true) {
18
+ CUDA_VERIFY(cudaEventCreate(&startEvent_));
19
+ CUDA_VERIFY(cudaEventCreate(&stopEvent_));
20
+
21
+ CUDA_VERIFY(cudaEventRecord(startEvent_, stream_));
25
22
  }
26
23
 
27
24
  KernelTimer::~KernelTimer() {
28
- CUDA_VERIFY(cudaEventDestroy(startEvent_));
29
- CUDA_VERIFY(cudaEventDestroy(stopEvent_));
25
+ CUDA_VERIFY(cudaEventDestroy(startEvent_));
26
+ CUDA_VERIFY(cudaEventDestroy(stopEvent_));
30
27
  }
31
28
 
32
- float
33
- KernelTimer::elapsedMilliseconds() {
34
- FAISS_ASSERT(valid_);
29
+ float KernelTimer::elapsedMilliseconds() {
30
+ FAISS_ASSERT(valid_);
35
31
 
36
- CUDA_VERIFY(cudaEventRecord(stopEvent_, stream_));
37
- CUDA_VERIFY(cudaEventSynchronize(stopEvent_));
32
+ CUDA_VERIFY(cudaEventRecord(stopEvent_, stream_));
33
+ CUDA_VERIFY(cudaEventSynchronize(stopEvent_));
38
34
 
39
- auto time = 0.0f;
40
- CUDA_VERIFY(cudaEventElapsedTime(&time, startEvent_, stopEvent_));
41
- valid_ = false;
35
+ auto time = 0.0f;
36
+ CUDA_VERIFY(cudaEventElapsedTime(&time, startEvent_, stopEvent_));
37
+ valid_ = false;
42
38
 
43
- return time;
39
+ return time;
44
40
  }
45
41
 
46
42
  CpuTimer::CpuTimer() {
47
- start_ = std::chrono::steady_clock::now();
43
+ start_ = std::chrono::steady_clock::now();
48
44
  }
49
45
 
50
- float
51
- CpuTimer::elapsedMilliseconds() {
52
- auto end = std::chrono::steady_clock::now();
46
+ float CpuTimer::elapsedMilliseconds() {
47
+ auto end = std::chrono::steady_clock::now();
53
48
 
54
- std::chrono::duration<float, std::milli> duration = end - start_;
49
+ std::chrono::duration<float, std::milli> duration = end - start_;
55
50
 
56
- return duration.count();
51
+ return duration.count();
57
52
  }
58
53
 
59
- } } // namespace
54
+ } // namespace gpu
55
+ } // namespace faiss
@@ -5,48 +5,49 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <cuda_runtime.h>
12
11
  #include <chrono>
13
12
 
14
- namespace faiss { namespace gpu {
13
+ namespace faiss {
14
+ namespace gpu {
15
15
 
16
16
  /// Utility class for timing execution of a kernel
17
17
  class KernelTimer {
18
- public:
19
- /// Constructor starts the timer and adds an event into the current
20
- /// device stream
21
- KernelTimer(cudaStream_t stream = 0);
22
-
23
- /// Destructor releases event resources
24
- ~KernelTimer();
25
-
26
- /// Adds a stop event then synchronizes on the stop event to get the
27
- /// actual GPU-side kernel timings for any kernels launched in the
28
- /// current stream. Returns the number of milliseconds elapsed.
29
- /// Can only be called once.
30
- float elapsedMilliseconds();
31
-
32
- private:
33
- cudaEvent_t startEvent_;
34
- cudaEvent_t stopEvent_;
35
- cudaStream_t stream_;
36
- bool valid_;
18
+ public:
19
+ /// Constructor starts the timer and adds an event into the current
20
+ /// device stream
21
+ KernelTimer(cudaStream_t stream = 0);
22
+
23
+ /// Destructor releases event resources
24
+ ~KernelTimer();
25
+
26
+ /// Adds a stop event then synchronizes on the stop event to get the
27
+ /// actual GPU-side kernel timings for any kernels launched in the
28
+ /// current stream. Returns the number of milliseconds elapsed.
29
+ /// Can only be called once.
30
+ float elapsedMilliseconds();
31
+
32
+ private:
33
+ cudaEvent_t startEvent_;
34
+ cudaEvent_t stopEvent_;
35
+ cudaStream_t stream_;
36
+ bool valid_;
37
37
  };
38
38
 
39
39
  /// CPU wallclock elapsed timer
40
40
  class CpuTimer {
41
- public:
42
- /// Creates and starts a new timer
43
- CpuTimer();
41
+ public:
42
+ /// Creates and starts a new timer
43
+ CpuTimer();
44
44
 
45
- /// Returns elapsed time in milliseconds
46
- float elapsedMilliseconds();
45
+ /// Returns elapsed time in milliseconds
46
+ float elapsedMilliseconds();
47
47
 
48
- private:
49
- std::chrono::time_point<std::chrono::steady_clock> start_;
48
+ private:
49
+ std::chrono::time_point<std::chrono::steady_clock> start_;
50
50
  };
51
51
 
52
- } } // namespace
52
+ } // namespace gpu
53
+ } // namespace faiss
@@ -0,0 +1,503 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/AdditiveQuantizer.h>
11
+
12
+ #include <cstddef>
13
+ #include <cstdio>
14
+ #include <cstring>
15
+ #include <memory>
16
+ #include <random>
17
+
18
+ #include <algorithm>
19
+
20
+ #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/utils/Heap.h>
22
+ #include <faiss/utils/distances.h>
23
+ #include <faiss/utils/hamming.h>
24
+ #include <faiss/utils/utils.h>
25
+
26
+ extern "C" {
27
+
28
+ // general matrix multiplication
29
+ int sgemm_(
30
+ const char* transa,
31
+ const char* transb,
32
+ FINTEGER* m,
33
+ FINTEGER* n,
34
+ FINTEGER* k,
35
+ const float* alpha,
36
+ const float* a,
37
+ FINTEGER* lda,
38
+ const float* b,
39
+ FINTEGER* ldb,
40
+ float* beta,
41
+ float* c,
42
+ FINTEGER* ldc);
43
+ }
44
+
45
+ namespace faiss {
46
+
47
+ AdditiveQuantizer::AdditiveQuantizer(
48
+ size_t d,
49
+ const std::vector<size_t>& nbits,
50
+ Search_type_t search_type)
51
+ : d(d),
52
+ M(nbits.size()),
53
+ nbits(nbits),
54
+ verbose(false),
55
+ is_trained(false),
56
+ search_type(search_type) {
57
+ norm_max = norm_min = NAN;
58
+ code_size = 0;
59
+ tot_bits = 0;
60
+ total_codebook_size = 0;
61
+ only_8bit = false;
62
+ set_derived_values();
63
+ }
64
+
65
+ AdditiveQuantizer::AdditiveQuantizer()
66
+ : AdditiveQuantizer(0, std::vector<size_t>()) {}
67
+
68
+ void AdditiveQuantizer::set_derived_values() {
69
+ tot_bits = 0;
70
+ only_8bit = true;
71
+ codebook_offsets.resize(M + 1, 0);
72
+ for (int i = 0; i < M; i++) {
73
+ int nbit = nbits[i];
74
+ size_t k = 1 << nbit;
75
+ codebook_offsets[i + 1] = codebook_offsets[i] + k;
76
+ tot_bits += nbit;
77
+ if (nbit != 0) {
78
+ only_8bit = false;
79
+ }
80
+ }
81
+ total_codebook_size = codebook_offsets[M];
82
+ switch (search_type) {
83
+ case ST_decompress:
84
+ case ST_LUT_nonorm:
85
+ case ST_norm_from_LUT:
86
+ break; // nothing to add
87
+ case ST_norm_float:
88
+ tot_bits += 32;
89
+ break;
90
+ case ST_norm_qint8:
91
+ case ST_norm_cqint8:
92
+ tot_bits += 8;
93
+ break;
94
+ case ST_norm_qint4:
95
+ case ST_norm_cqint4:
96
+ tot_bits += 4;
97
+ break;
98
+ }
99
+
100
+ // convert bits to bytes
101
+ code_size = (tot_bits + 7) / 8;
102
+ }
103
+
104
+ namespace {
105
+
106
+ // TODO
107
+ // https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0
108
+
109
+ uint8_t encode_qint8(float x, float amin, float amax) {
110
+ float x1 = (x - amin) / (amax - amin) * 256;
111
+ int32_t xi = int32_t(floor(x1));
112
+
113
+ return xi < 0 ? 0 : xi > 255 ? 255 : xi;
114
+ }
115
+
116
+ uint8_t encode_qint4(float x, float amin, float amax) {
117
+ float x1 = (x - amin) / (amax - amin) * 16;
118
+ int32_t xi = int32_t(floor(x1));
119
+
120
+ return xi < 0 ? 0 : xi > 15 ? 15 : xi;
121
+ }
122
+
123
+ float decode_qint8(uint8_t i, float amin, float amax) {
124
+ return (i + 0.5) / 256 * (amax - amin) + amin;
125
+ }
126
+
127
+ float decode_qint4(uint8_t i, float amin, float amax) {
128
+ return (i + 0.5) / 16 * (amax - amin) + amin;
129
+ }
130
+
131
+ } // anonymous namespace
132
+
133
+ uint32_t AdditiveQuantizer::encode_qcint(float x) const {
134
+ idx_t id;
135
+ qnorm.assign(idx_t(1), &x, &id, idx_t(1));
136
+ return uint32_t(id);
137
+ }
138
+
139
+ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
140
+ return qnorm.get_xb()[c];
141
+ }
142
+
143
+ void AdditiveQuantizer::pack_codes(
144
+ size_t n,
145
+ const int32_t* codes,
146
+ uint8_t* packed_codes,
147
+ int64_t ld_codes,
148
+ const float* norms) const {
149
+ if (ld_codes == -1) {
150
+ ld_codes = M;
151
+ }
152
+ std::vector<float> norm_buf;
153
+ if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
154
+ search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
155
+ search_type == ST_norm_cqint4) {
156
+ if (!norms) {
157
+ norm_buf.resize(n);
158
+ std::vector<float> x_recons(n * d);
159
+ decode_unpacked(codes, x_recons.data(), n, ld_codes);
160
+ fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
161
+ norms = norm_buf.data();
162
+ }
163
+ }
164
+ #pragma omp parallel for if (n > 1000)
165
+ for (int64_t i = 0; i < n; i++) {
166
+ const int32_t* codes1 = codes + i * ld_codes;
167
+ BitstringWriter bsw(packed_codes + i * code_size, code_size);
168
+ for (int m = 0; m < M; m++) {
169
+ bsw.write(codes1[m], nbits[m]);
170
+ }
171
+ switch (search_type) {
172
+ case ST_decompress:
173
+ case ST_LUT_nonorm:
174
+ case ST_norm_from_LUT:
175
+ break;
176
+ case ST_norm_float:
177
+ bsw.write(*(uint32_t*)&norms[i], 32);
178
+ break;
179
+ case ST_norm_qint8: {
180
+ uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
181
+ bsw.write(b, 8);
182
+ break;
183
+ }
184
+ case ST_norm_qint4: {
185
+ uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
186
+ bsw.write(b, 4);
187
+ break;
188
+ }
189
+ case ST_norm_cqint8: {
190
+ uint32_t b = encode_qcint(norms[i]);
191
+ bsw.write(b, 8);
192
+ break;
193
+ }
194
+ case ST_norm_cqint4: {
195
+ uint32_t b = encode_qcint(norms[i]);
196
+ bsw.write(b, 4);
197
+ break;
198
+ }
199
+ }
200
+ }
201
+ }
202
+
203
+ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
204
+ FAISS_THROW_IF_NOT_MSG(
205
+ is_trained, "The additive quantizer is not trained yet.");
206
+
207
+ // standard additive quantizer decoding
208
+ #pragma omp parallel for if (n > 1000)
209
+ for (int64_t i = 0; i < n; i++) {
210
+ BitstringReader bsr(code + i * code_size, code_size);
211
+ float* xi = x + i * d;
212
+ for (int m = 0; m < M; m++) {
213
+ int idx = bsr.read(nbits[m]);
214
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
215
+ if (m == 0) {
216
+ memcpy(xi, c, sizeof(*x) * d);
217
+ } else {
218
+ fvec_add(d, xi, c, xi);
219
+ }
220
+ }
221
+ }
222
+ }
223
+
224
+ void AdditiveQuantizer::decode_unpacked(
225
+ const int32_t* code,
226
+ float* x,
227
+ size_t n,
228
+ int64_t ld_codes) const {
229
+ FAISS_THROW_IF_NOT_MSG(
230
+ is_trained, "The additive quantizer is not trained yet.");
231
+
232
+ if (ld_codes == -1) {
233
+ ld_codes = M;
234
+ }
235
+
236
+ // standard additive quantizer decoding
237
+ #pragma omp parallel for if (n > 1000)
238
+ for (int64_t i = 0; i < n; i++) {
239
+ const int32_t* codesi = code + i * ld_codes;
240
+ float* xi = x + i * d;
241
+ for (int m = 0; m < M; m++) {
242
+ int idx = codesi[m];
243
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
244
+ if (m == 0) {
245
+ memcpy(xi, c, sizeof(*x) * d);
246
+ } else {
247
+ fvec_add(d, xi, c, xi);
248
+ }
249
+ }
250
+ }
251
+ }
252
+
253
+ AdditiveQuantizer::~AdditiveQuantizer() {}
254
+
255
+ /****************************************************************************
256
+ * Support for fast distance computations in centroids
257
+ ****************************************************************************/
258
+
259
+ void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
260
+ size_t ntotal = (size_t)1 << tot_bits;
261
+ // TODO: make tree of partial sums
262
+ #pragma omp parallel
263
+ {
264
+ std::vector<float> tmp(d);
265
+ #pragma omp for
266
+ for (int64_t i = 0; i < ntotal; i++) {
267
+ decode_64bit(i, tmp.data());
268
+ norms[i] = fvec_norm_L2sqr(tmp.data(), d);
269
+ }
270
+ }
271
+ }
272
+
273
+ void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
274
+ for (int m = 0; m < M; m++) {
275
+ idx_t idx = bits & (((size_t)1 << nbits[m]) - 1);
276
+ bits >>= nbits[m];
277
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
278
+ if (m == 0) {
279
+ memcpy(xi, c, sizeof(*xi) * d);
280
+ } else {
281
+ fvec_add(d, xi, c, xi);
282
+ }
283
+ }
284
+ }
285
+
286
+ void AdditiveQuantizer::compute_LUT(size_t n, const float* xq, float* LUT)
287
+ const {
288
+ // in all cases, it is large matrix multiplication
289
+
290
+ FINTEGER ncenti = total_codebook_size;
291
+ FINTEGER di = d;
292
+ FINTEGER nqi = n;
293
+ float one = 1, zero = 0;
294
+
295
+ sgemm_("Transposed",
296
+ "Not transposed",
297
+ &ncenti,
298
+ &nqi,
299
+ &di,
300
+ &one,
301
+ codebooks.data(),
302
+ &di,
303
+ xq,
304
+ &di,
305
+ &zero,
306
+ LUT,
307
+ &ncenti);
308
+ }
309
+
310
+ namespace {
311
+
312
+ void compute_inner_prod_with_LUT(
313
+ const AdditiveQuantizer& aq,
314
+ const float* LUT,
315
+ float* ips) {
316
+ size_t prev_size = 1;
317
+ for (int m = 0; m < aq.M; m++) {
318
+ const float* LUTm = LUT + aq.codebook_offsets[m];
319
+ int nb = aq.nbits[m];
320
+ size_t nc = (size_t)1 << nb;
321
+
322
+ if (m == 0) {
323
+ memcpy(ips, LUT, sizeof(*ips) * nc);
324
+ } else {
325
+ for (int64_t i = nc - 1; i >= 0; i--) {
326
+ float v = LUTm[i];
327
+ fvec_add(prev_size, ips, v, ips + i * prev_size);
328
+ }
329
+ }
330
+ prev_size *= nc;
331
+ }
332
+ }
333
+
334
+ } // anonymous namespace
335
+
336
+ void AdditiveQuantizer::knn_centroids_inner_product(
337
+ idx_t n,
338
+ const float* xq,
339
+ idx_t k,
340
+ float* distances,
341
+ idx_t* labels) const {
342
+ std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
343
+ compute_LUT(n, xq, LUT.get());
344
+ size_t ntotal = (size_t)1 << tot_bits;
345
+
346
+ #pragma omp parallel if (n > 100)
347
+ {
348
+ std::vector<float> dis(ntotal);
349
+ #pragma omp for
350
+ for (idx_t i = 0; i < n; i++) {
351
+ const float* LUTi = LUT.get() + i * total_codebook_size;
352
+ compute_inner_prod_with_LUT(*this, LUTi, dis.data());
353
+ float* distances_i = distances + i * k;
354
+ idx_t* labels_i = labels + i * k;
355
+ minheap_heapify(k, distances_i, labels_i);
356
+ minheap_addn(k, distances_i, labels_i, dis.data(), nullptr, ntotal);
357
+ minheap_reorder(k, distances_i, labels_i);
358
+ }
359
+ }
360
+ }
361
+
362
+ void AdditiveQuantizer::knn_centroids_L2(
363
+ idx_t n,
364
+ const float* xq,
365
+ idx_t k,
366
+ float* distances,
367
+ idx_t* labels,
368
+ const float* norms) const {
369
+ std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]);
370
+ compute_LUT(n, xq, LUT.get());
371
+ std::unique_ptr<float[]> q_norms(new float[n]);
372
+ fvec_norms_L2sqr(q_norms.get(), xq, d, n);
373
+ size_t ntotal = (size_t)1 << tot_bits;
374
+
375
+ #pragma omp parallel if (n > 100)
376
+ {
377
+ std::vector<float> dis(ntotal);
378
+ #pragma omp for
379
+ for (idx_t i = 0; i < n; i++) {
380
+ const float* LUTi = LUT.get() + i * total_codebook_size;
381
+ float* distances_i = distances + i * k;
382
+ idx_t* labels_i = labels + i * k;
383
+
384
+ compute_inner_prod_with_LUT(*this, LUTi, dis.data());
385
+
386
+ // update distances using
387
+ // ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x,y>
388
+
389
+ maxheap_heapify(k, distances_i, labels_i);
390
+ for (idx_t j = 0; j < ntotal; j++) {
391
+ float disj = q_norms[i] + norms[j] - 2 * dis[j];
392
+ if (disj < distances_i[0]) {
393
+ heap_replace_top<CMax<float, int64_t>>(
394
+ k, distances_i, labels_i, disj, j);
395
+ }
396
+ }
397
+ maxheap_reorder(k, distances_i, labels_i);
398
+ }
399
+ }
400
+ }
401
+
402
+ /****************************************************************************
403
+ * Support for fast distance computations in codes
404
+ ****************************************************************************/
405
+
406
+ namespace {
407
+
408
+ float accumulate_IPs(
409
+ const AdditiveQuantizer& aq,
410
+ BitstringReader& bs,
411
+ const uint8_t* codes,
412
+ const float* LUT) {
413
+ float accu = 0;
414
+ for (int m = 0; m < aq.M; m++) {
415
+ size_t nbit = aq.nbits[m];
416
+ int idx = bs.read(nbit);
417
+ accu += LUT[idx];
418
+ LUT += (uint64_t)1 << nbit;
419
+ }
420
+ return accu;
421
+ }
422
+
423
+ } // anonymous namespace
424
+
425
+ template <>
426
+ float AdditiveQuantizer::
427
+ compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
428
+ const uint8_t* codes,
429
+ const float* LUT) const {
430
+ BitstringReader bs(codes, code_size);
431
+ return accumulate_IPs(*this, bs, codes, LUT);
432
+ }
433
+
434
+ template <>
435
+ float AdditiveQuantizer::
436
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>(
437
+ const uint8_t* codes,
438
+ const float* LUT) const {
439
+ BitstringReader bs(codes, code_size);
440
+ return -accumulate_IPs(*this, bs, codes, LUT);
441
+ }
442
+
443
+ template <>
444
+ float AdditiveQuantizer::
445
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>(
446
+ const uint8_t* codes,
447
+ const float* LUT) const {
448
+ BitstringReader bs(codes, code_size);
449
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
450
+ uint32_t norm_i = bs.read(32);
451
+ float norm2 = *(float*)&norm_i;
452
+ return norm2 - 2 * accu;
453
+ }
454
+
455
+ template <>
456
+ float AdditiveQuantizer::
457
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
458
+ const uint8_t* codes,
459
+ const float* LUT) const {
460
+ BitstringReader bs(codes, code_size);
461
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
462
+ uint32_t norm_i = bs.read(8);
463
+ float norm2 = decode_qcint(norm_i);
464
+ return norm2 - 2 * accu;
465
+ }
466
+
467
+ template <>
468
+ float AdditiveQuantizer::
469
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>(
470
+ const uint8_t* codes,
471
+ const float* LUT) const {
472
+ BitstringReader bs(codes, code_size);
473
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
474
+ uint32_t norm_i = bs.read(4);
475
+ float norm2 = decode_qcint(norm_i);
476
+ return norm2 - 2 * accu;
477
+ }
478
+
479
+ template <>
480
+ float AdditiveQuantizer::
481
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>(
482
+ const uint8_t* codes,
483
+ const float* LUT) const {
484
+ BitstringReader bs(codes, code_size);
485
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
486
+ uint32_t norm_i = bs.read(8);
487
+ float norm2 = decode_qint8(norm_i, norm_min, norm_max);
488
+ return norm2 - 2 * accu;
489
+ }
490
+
491
+ template <>
492
+ float AdditiveQuantizer::
493
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>(
494
+ const uint8_t* codes,
495
+ const float* LUT) const {
496
+ BitstringReader bs(codes, code_size);
497
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
498
+ uint32_t norm_i = bs.read(4);
499
+ float norm2 = decode_qint4(norm_i, norm_min, norm_max);
500
+ return norm2 - 2 * accu;
501
+ }
502
+
503
+ } // namespace faiss