faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,175 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstdint>
11
+ #include <vector>
12
+
13
+ #include <faiss/Index.h>
14
+ #include <faiss/IndexFlat.h>
15
+
16
+ namespace faiss {
17
+
18
+ /** Abstract structure for additive quantizers
19
+ *
20
+ * Different from the product quantizer in which the decoded vector is the
21
+ * concatenation of M sub-vectors, additive quantizers sum M sub-vectors
22
+ * to get the decoded vector.
23
+ */
24
+ struct AdditiveQuantizer {
25
+ size_t d; ///< size of the input vectors
26
+ size_t M; ///< number of codebooks
27
+ std::vector<size_t> nbits; ///< bits for each step
28
+ std::vector<float> codebooks; ///< codebooks
29
+
30
+ // derived values
31
+ std::vector<uint64_t> codebook_offsets;
32
+ size_t code_size; ///< code size in bytes
33
+ size_t tot_bits; ///< total number of bits
34
+ size_t total_codebook_size; ///< size of the codebook in vectors
35
+ bool only_8bit; ///< are all nbits = 8 (use faster decoder)
36
+
37
+ bool verbose; ///< verbose during training?
38
+ bool is_trained; ///< is trained or not
39
+
40
+ IndexFlat1D qnorm; ///< store and search norms
41
+
42
+ uint32_t encode_qcint(
43
+ float x) const; ///< encode norm by non-uniform scalar quantization
44
+
45
+ float decode_qcint(uint32_t c)
46
+ const; ///< decode norm by non-uniform scalar quantization
47
+
48
+ /// Encodes how search is performed and how vectors are encoded
49
+ enum Search_type_t {
50
+ ST_decompress, ///< decompress database vector
51
+ ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or
52
+ ///< normalized vectors)
53
+ ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
54
+ ///< is in O(M^2))
55
+ ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
56
+ ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
57
+ ST_norm_qint4,
58
+ ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
59
+ ST_norm_cqint4,
60
+ };
61
+
62
+ AdditiveQuantizer(
63
+ size_t d,
64
+ const std::vector<size_t>& nbits,
65
+ Search_type_t search_type = ST_decompress);
66
+
67
+ AdditiveQuantizer();
68
+
69
+ /// compute derived values when d, M and nbits have been set
70
+ void set_derived_values();
71
+
72
+ /// Train the additive quantizer
73
+ virtual void train(size_t n, const float* x) = 0;
74
+
75
+ /** Encode a set of vectors
76
+ *
77
+ * @param x vectors to encode, size n * d
78
+ * @param codes output codes, size n * code_size
79
+ */
80
+ virtual void compute_codes(const float* x, uint8_t* codes, size_t n)
81
+ const = 0;
82
+
83
+ /** pack a series of codes into bit-compact format
84
+ *
85
+ * @param codes codes to be packed, size n * code_size
86
+ * @param packed_codes output bit-compact codes
87
+ * @param ld_codes leading dimension of codes
88
+ * @param norms norms of the vectors (size n). Will be computed if
89
+ * needed but not provided
90
+ */
91
+ void pack_codes(
92
+ size_t n,
93
+ const int32_t* codes,
94
+ uint8_t* packed_codes,
95
+ int64_t ld_codes = -1,
96
+ const float* norms = nullptr) const;
97
+
98
+ /** Decode a set of vectors
99
+ *
100
+ * @param codes codes to decode, size n * code_size
101
+ * @param x output vectors, size n * d
102
+ */
103
+ void decode(const uint8_t* codes, float* x, size_t n) const;
104
+
105
+ /** Decode a set of vectors in non-packed format
106
+ *
107
+ * @param codes codes to decode, size n * ld_codes
108
+ * @param x output vectors, size n * d
109
+ */
110
+ void decode_unpacked(
111
+ const int32_t* codes,
112
+ float* x,
113
+ size_t n,
114
+ int64_t ld_codes = -1) const;
115
+
116
+ /****************************************************************************
117
+ * Search functions in an external set of codes.
118
+ ****************************************************************************/
119
+
120
+ /// Also determines what's in the codes
121
+ Search_type_t search_type;
122
+
123
+ /// min/max for quantization of norms
124
+ float norm_min, norm_max;
125
+
126
+ template <bool is_IP, Search_type_t effective_search_type>
127
+ float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
128
+
129
+ /*
130
+ float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
131
+ */
132
+ /****************************************************************************
133
+ * Support for exhaustive distance computations with all the centroids.
134
+ * Hence, the number of these centroids should not be too large.
135
+ ****************************************************************************/
136
+ using idx_t = Index::idx_t;
137
+
138
+ /// decoding function for a code in a 64-bit word
139
+ void decode_64bit(idx_t n, float* x) const;
140
+
141
+ /** Compute inner-product look-up tables. Used in the centroid search
142
+ * functions.
143
+ *
144
+ * @param xq query vector, size (n, d)
145
+ * @param LUT look-up table, size (n, total_codebook_size)
146
+ */
147
+ void compute_LUT(size_t n, const float* xq, float* LUT) const;
148
+
149
+ /// exact IP search
150
+ void knn_centroids_inner_product(
151
+ idx_t n,
152
+ const float* xq,
153
+ idx_t k,
154
+ float* distances,
155
+ idx_t* labels) const;
156
+
157
+ /** For L2 search we need the L2 norms of the centroids
158
+ *
159
+ * @param norms output norms table, size total_codebook_size
160
+ */
161
+ void compute_centroid_norms(float* norms) const;
162
+
163
+ /** Exact L2 search, with precomputed norms */
164
+ void knn_centroids_L2(
165
+ idx_t n,
166
+ const float* xq,
167
+ idx_t k,
168
+ float* distances,
169
+ idx_t* labels,
170
+ const float* centroid_norms) const;
171
+
172
+ virtual ~AdditiveQuantizer();
173
+ };
174
+
175
+ }; // namespace faiss
@@ -14,18 +14,16 @@
14
14
 
15
15
  #include <faiss/impl/FaissAssert.h>
16
16
 
17
-
18
17
  namespace faiss {
19
18
 
20
-
21
19
  /***********************************************************************
22
20
  * RangeSearchResult
23
21
  ***********************************************************************/
24
22
 
25
- RangeSearchResult::RangeSearchResult (idx_t nq, bool alloc_lims): nq (nq) {
23
+ RangeSearchResult::RangeSearchResult(idx_t nq, bool alloc_lims) : nq(nq) {
26
24
  if (alloc_lims) {
27
- lims = new size_t [nq + 1];
28
- memset (lims, 0, sizeof(*lims) * (nq + 1));
25
+ lims = new size_t[nq + 1];
26
+ memset(lims, 0, sizeof(*lims) * (nq + 1));
29
27
  } else {
30
28
  lims = nullptr;
31
29
  }
@@ -36,145 +34,129 @@ RangeSearchResult::RangeSearchResult (idx_t nq, bool alloc_lims): nq (nq) {
36
34
 
37
35
  /// called when lims contains the nb of elements result entries
38
36
  /// for each query
39
- void RangeSearchResult::do_allocation () {
37
+ void RangeSearchResult::do_allocation() {
38
+ // works only if all the partial results are aggregated
39
+ // simultaneously
40
+ FAISS_THROW_IF_NOT(labels == nullptr && distances == nullptr);
40
41
  size_t ofs = 0;
41
42
  for (int i = 0; i < nq; i++) {
42
43
  size_t n = lims[i];
43
- lims [i] = ofs;
44
+ lims[i] = ofs;
44
45
  ofs += n;
45
46
  }
46
- lims [nq] = ofs;
47
- labels = new idx_t [ofs];
48
- distances = new float [ofs];
47
+ lims[nq] = ofs;
48
+ labels = new idx_t[ofs];
49
+ distances = new float[ofs];
49
50
  }
50
51
 
51
- RangeSearchResult::~RangeSearchResult () {
52
- delete [] labels;
53
- delete [] distances;
54
- delete [] lims;
52
+ RangeSearchResult::~RangeSearchResult() {
53
+ delete[] labels;
54
+ delete[] distances;
55
+ delete[] lims;
55
56
  }
56
57
 
57
-
58
-
59
-
60
-
61
58
  /***********************************************************************
62
59
  * BufferList
63
60
  ***********************************************************************/
64
61
 
65
-
66
- BufferList::BufferList (size_t buffer_size):
67
- buffer_size (buffer_size)
68
- {
62
+ BufferList::BufferList(size_t buffer_size) : buffer_size(buffer_size) {
69
63
  wp = buffer_size;
70
64
  }
71
65
 
72
- BufferList::~BufferList ()
73
- {
66
+ BufferList::~BufferList() {
74
67
  for (int i = 0; i < buffers.size(); i++) {
75
- delete [] buffers[i].ids;
76
- delete [] buffers[i].dis;
68
+ delete[] buffers[i].ids;
69
+ delete[] buffers[i].dis;
77
70
  }
78
71
  }
79
72
 
80
- void BufferList::add (idx_t id, float dis) {
73
+ void BufferList::add(idx_t id, float dis) {
81
74
  if (wp == buffer_size) { // need new buffer
82
75
  append_buffer();
83
76
  }
84
- Buffer & buf = buffers.back();
85
- buf.ids [wp] = id;
86
- buf.dis [wp] = dis;
77
+ Buffer& buf = buffers.back();
78
+ buf.ids[wp] = id;
79
+ buf.dis[wp] = dis;
87
80
  wp++;
88
81
  }
89
82
 
90
-
91
- void BufferList::append_buffer ()
92
- {
93
- Buffer buf = {new idx_t [buffer_size], new float [buffer_size]};
94
- buffers.push_back (buf);
83
+ void BufferList::append_buffer() {
84
+ Buffer buf = {new idx_t[buffer_size], new float[buffer_size]};
85
+ buffers.push_back(buf);
95
86
  wp = 0;
96
87
  }
97
88
 
98
89
  /// copy elements ofs:ofs+n-1 seen as linear data in the buffers to
99
90
  /// tables dest_ids, dest_dis
100
- void BufferList::copy_range (size_t ofs, size_t n,
101
- idx_t * dest_ids, float *dest_dis)
102
- {
91
+ void BufferList::copy_range(
92
+ size_t ofs,
93
+ size_t n,
94
+ idx_t* dest_ids,
95
+ float* dest_dis) {
103
96
  size_t bno = ofs / buffer_size;
104
97
  ofs -= bno * buffer_size;
105
98
  while (n > 0) {
106
99
  size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
107
- Buffer buf = buffers [bno];
108
- memcpy (dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
109
- memcpy (dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
100
+ Buffer buf = buffers[bno];
101
+ memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
102
+ memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
110
103
  dest_ids += ncopy;
111
104
  dest_dis += ncopy;
112
105
  ofs = 0;
113
- bno ++;
106
+ bno++;
114
107
  n -= ncopy;
115
108
  }
116
109
  }
117
110
 
118
-
119
111
  /***********************************************************************
120
112
  * RangeSearchPartialResult
121
113
  ***********************************************************************/
122
114
 
123
- void RangeQueryResult::add (float dis, idx_t id) {
115
+ void RangeQueryResult::add(float dis, idx_t id) {
124
116
  nres++;
125
- pres->add (id, dis);
117
+ pres->add(id, dis);
126
118
  }
127
119
 
128
-
129
-
130
- RangeSearchPartialResult::RangeSearchPartialResult (RangeSearchResult * res_in):
131
- BufferList(res_in->buffer_size),
132
- res(res_in)
133
- {}
134
-
120
+ RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult* res_in)
121
+ : BufferList(res_in->buffer_size), res(res_in) {}
135
122
 
136
123
  /// begin a new result
137
- RangeQueryResult &
138
- RangeSearchPartialResult::new_result (idx_t qno)
139
- {
124
+ RangeQueryResult& RangeSearchPartialResult::new_result(idx_t qno) {
140
125
  RangeQueryResult qres = {qno, 0, this};
141
- queries.push_back (qres);
126
+ queries.push_back(qres);
142
127
  return queries.back();
143
128
  }
144
129
 
145
-
146
- void RangeSearchPartialResult::finalize ()
147
- {
148
- set_lims ();
130
+ void RangeSearchPartialResult::finalize() {
131
+ set_lims();
149
132
  #pragma omp barrier
150
133
 
151
134
  #pragma omp single
152
- res->do_allocation ();
135
+ res->do_allocation();
153
136
 
154
137
  #pragma omp barrier
155
- copy_result ();
138
+ copy_result();
156
139
  }
157
140
 
158
-
159
141
  /// called by range_search before do_allocation
160
- void RangeSearchPartialResult::set_lims ()
161
- {
142
+ void RangeSearchPartialResult::set_lims() {
162
143
  for (int i = 0; i < queries.size(); i++) {
163
- RangeQueryResult & qres = queries[i];
144
+ RangeQueryResult& qres = queries[i];
164
145
  res->lims[qres.qno] = qres.nres;
165
146
  }
166
147
  }
167
148
 
168
149
  /// called by range_search after do_allocation
169
- void RangeSearchPartialResult::copy_result (bool incremental)
170
- {
150
+ void RangeSearchPartialResult::copy_result(bool incremental) {
171
151
  size_t ofs = 0;
172
152
  for (int i = 0; i < queries.size(); i++) {
173
- RangeQueryResult & qres = queries[i];
153
+ RangeQueryResult& qres = queries[i];
174
154
 
175
- copy_range (ofs, qres.nres,
176
- res->labels + res->lims[qres.qno],
177
- res->distances + res->lims[qres.qno]);
155
+ copy_range(
156
+ ofs,
157
+ qres.nres,
158
+ res->labels + res->lims[qres.qno],
159
+ res->distances + res->lims[qres.qno]);
178
160
  if (incremental) {
179
161
  res->lims[qres.qno] += qres.nres;
180
162
  }
@@ -182,26 +164,28 @@ void RangeSearchPartialResult::copy_result (bool incremental)
182
164
  }
183
165
  }
184
166
 
185
- void RangeSearchPartialResult::merge (std::vector <RangeSearchPartialResult *> &
186
- partial_results, bool do_delete)
187
- {
188
-
167
+ void RangeSearchPartialResult::merge(
168
+ std::vector<RangeSearchPartialResult*>& partial_results,
169
+ bool do_delete) {
189
170
  int npres = partial_results.size();
190
- if (npres == 0) return;
191
- RangeSearchResult *result = partial_results[0]->res;
171
+ if (npres == 0)
172
+ return;
173
+ RangeSearchResult* result = partial_results[0]->res;
192
174
  size_t nx = result->nq;
193
175
 
194
176
  // count
195
- for (const RangeSearchPartialResult * pres : partial_results) {
196
- if (!pres) continue;
197
- for (const RangeQueryResult &qres : pres->queries) {
177
+ for (const RangeSearchPartialResult* pres : partial_results) {
178
+ if (!pres)
179
+ continue;
180
+ for (const RangeQueryResult& qres : pres->queries) {
198
181
  result->lims[qres.qno] += qres.nres;
199
182
  }
200
183
  }
201
- result->do_allocation ();
184
+ result->do_allocation();
202
185
  for (int j = 0; j < npres; j++) {
203
- if (!partial_results[j]) continue;
204
- partial_results[j]->copy_result (true);
186
+ if (!partial_results[j])
187
+ continue;
188
+ partial_results[j]->copy_result(true);
205
189
  if (do_delete) {
206
190
  delete partial_results[j];
207
191
  partial_results[j] = nullptr;
@@ -210,22 +194,19 @@ void RangeSearchPartialResult::merge (std::vector <RangeSearchPartialResult *> &
210
194
 
211
195
  // reset the limits
212
196
  for (size_t i = nx; i > 0; i--) {
213
- result->lims [i] = result->lims [i - 1];
197
+ result->lims[i] = result->lims[i - 1];
214
198
  }
215
- result->lims [0] = 0;
199
+ result->lims[0] = 0;
216
200
  }
217
201
 
218
202
  /***********************************************************************
219
203
  * IDSelectorRange
220
204
  ***********************************************************************/
221
205
 
222
- IDSelectorRange::IDSelectorRange (idx_t imin, idx_t imax):
223
- imin (imin), imax (imax)
224
- {
225
- }
206
+ IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax)
207
+ : imin(imin), imax(imax) {}
226
208
 
227
- bool IDSelectorRange::is_member (idx_t id) const
228
- {
209
+ bool IDSelectorRange::is_member(idx_t id) const {
229
210
  return id >= imin && id < imax;
230
211
  }
231
212
 
@@ -233,33 +214,29 @@ bool IDSelectorRange::is_member (idx_t id) const
233
214
  * IDSelectorArray
234
215
  ***********************************************************************/
235
216
 
236
- IDSelectorArray::IDSelectorArray (size_t n, const idx_t *ids):
237
- n (n), ids(ids)
238
- {
239
- }
217
+ IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
240
218
 
241
- bool IDSelectorArray::is_member (idx_t id) const
242
- {
219
+ bool IDSelectorArray::is_member(idx_t id) const {
243
220
  for (idx_t i = 0; i < n; i++) {
244
- if (ids[i] == id) return true;
221
+ if (ids[i] == id)
222
+ return true;
245
223
  }
246
224
  return false;
247
225
  }
248
226
 
249
-
250
227
  /***********************************************************************
251
228
  * IDSelectorBatch
252
229
  ***********************************************************************/
253
230
 
254
- IDSelectorBatch::IDSelectorBatch (size_t n, const idx_t *indices)
255
- {
231
+ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
256
232
  nbits = 0;
257
- while (n > (1L << nbits)) nbits++;
233
+ while (n > (1L << nbits))
234
+ nbits++;
258
235
  nbits += 5;
259
236
  // for n = 1M, nbits = 25 is optimal, see P56659518
260
237
 
261
238
  mask = (1L << nbits) - 1;
262
- bloom.resize (1UL << (nbits - 3), 0);
239
+ bloom.resize(1UL << (nbits - 3), 0);
263
240
  for (long i = 0; i < n; i++) {
264
241
  Index::idx_t id = indices[i];
265
242
  set.insert(id);
@@ -268,39 +245,36 @@ IDSelectorBatch::IDSelectorBatch (size_t n, const idx_t *indices)
268
245
  }
269
246
  }
270
247
 
271
- bool IDSelectorBatch::is_member (idx_t i) const
272
- {
248
+ bool IDSelectorBatch::is_member(idx_t i) const {
273
249
  long im = i & mask;
274
- if(!(bloom[im>>3] & (1 << (im & 7)))) {
250
+ if (!(bloom[im >> 3] & (1 << (im & 7)))) {
275
251
  return 0;
276
252
  }
277
253
  return set.count(i);
278
254
  }
279
255
 
280
-
281
256
  /***********************************************************
282
257
  * Interrupt callback
283
258
  ***********************************************************/
284
259
 
285
-
286
260
  std::unique_ptr<InterruptCallback> InterruptCallback::instance;
287
261
 
288
262
  std::mutex InterruptCallback::lock;
289
263
 
290
- void InterruptCallback::clear_instance () {
291
- delete instance.release ();
264
+ void InterruptCallback::clear_instance() {
265
+ delete instance.release();
292
266
  }
293
267
 
294
- void InterruptCallback::check () {
268
+ void InterruptCallback::check() {
295
269
  if (!instance.get()) {
296
270
  return;
297
271
  }
298
- if (instance->want_interrupt ()) {
299
- FAISS_THROW_MSG ("computation interrupted");
272
+ if (instance->want_interrupt()) {
273
+ FAISS_THROW_MSG("computation interrupted");
300
274
  }
301
275
  }
302
276
 
303
- bool InterruptCallback::is_interrupted () {
277
+ bool InterruptCallback::is_interrupted() {
304
278
  if (!instance.get()) {
305
279
  return false;
306
280
  }
@@ -308,8 +282,7 @@ bool InterruptCallback::is_interrupted () {
308
282
  return instance->want_interrupt();
309
283
  }
310
284
 
311
-
312
- size_t InterruptCallback::get_period_hint (size_t flops) {
285
+ size_t InterruptCallback::get_period_hint(size_t flops) {
313
286
  if (!instance.get()) {
314
287
  return 1L << 30; // never check
315
288
  }
@@ -317,7 +290,4 @@ size_t InterruptCallback::get_period_hint (size_t flops) {
317
290
  return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
318
291
  }
319
292
 
320
-
321
-
322
-
323
293
  } // namespace faiss