faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,188 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstdint>
11
+ #include <vector>
12
+
13
+ #include <faiss/Clustering.h>
14
+ #include <faiss/impl/AdditiveQuantizer.h>
15
+
16
+ namespace faiss {
17
+
18
+ /** Residual quantizer with variable number of bits per sub-quantizer
19
+ *
20
+ * The residual centroids are stored in a big cumulative centroid table.
21
+ * The codes are represented either as a non-compact table of size (n, M) or
22
+ * as the compact output (n, code_size).
23
+ */
24
+
25
+ struct ResidualQuantizer : AdditiveQuantizer {
26
+ /// initialization
27
+ enum train_type_t {
28
+ Train_default = 0, ///< regular k-means
29
+ Train_progressive_dim = 1, ///< progressive dim clustering
30
+ Train_default_Train_top_beam = 1024,
31
+ Train_progressive_dim_Train_top_beam = 1025,
32
+ Train_default_Skip_codebook_tables = 2048,
33
+ Train_progressive_dim_Skip_codebook_tables = 2049,
34
+ Train_default_Train_top_beam_Skip_codebook_tables = 3072,
35
+ Train_progressive_dim_Train_top_beam_Skip_codebook_tables = 3073,
36
+ };
37
+
38
+ train_type_t train_type;
39
+
40
+ // set this bit on train_type if beam is to be trained only on the
41
+ // first element of the beam (faster but less accurate)
42
+ static const int Train_top_beam = 1024;
43
+
44
+ // set this bit to not automatically compute the codebook tables
45
+ // after training
46
+ static const int Skip_codebook_tables = 2048;
47
+
48
+ /// beam size used for training and for encoding
49
+ int max_beam_size;
50
+
51
+ /// use LUT for beam search
52
+ int use_beam_LUT;
53
+
54
+ /// distance matrixes with beam search can get large, so use this
55
+ /// to batch computations at encoding time.
56
+ size_t max_mem_distances;
57
+
58
+ /// clustering parameters
59
+ ProgressiveDimClusteringParameters cp;
60
+
61
+ /// if non-NULL, use this index for assignment
62
+ ProgressiveDimIndexFactory* assign_index_factory;
63
+
64
+ ResidualQuantizer(
65
+ size_t d,
66
+ const std::vector<size_t>& nbits,
67
+ Search_type_t search_type = ST_decompress);
68
+
69
+ ResidualQuantizer(
70
+ size_t d, /* dimensionality of the input vectors */
71
+ size_t M, /* number of subquantizers */
72
+ size_t nbits, /* number of bits per subvector index */
73
+ Search_type_t search_type = ST_decompress);
74
+
75
+ ResidualQuantizer();
76
+
77
+ // Train the residual quantizer
78
+ void train(size_t n, const float* x) override;
79
+
80
+ /** Encode a set of vectors
81
+ *
82
+ * @param x vectors to encode, size n * d
83
+ * @param codes output codes, size n * code_size
84
+ */
85
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
86
+
87
+ /** lower-level encode function
88
+ *
89
+ * @param n number of vectors to handle
90
+ * @param residuals vectors to encode, size (n, beam_size, d)
91
+ * @param beam_size input beam size
92
+ * @param new_beam_size output beam size (should be <= K * beam_size)
93
+ * @param new_codes output codes, size (n, new_beam_size, m + 1)
94
+ * @param new_residuals output residuals, size (n, new_beam_size, d)
95
+ * @param new_distances output distances, size (n, new_beam_size)
96
+ */
97
+ void refine_beam(
98
+ size_t n,
99
+ size_t beam_size,
100
+ const float* residuals,
101
+ int new_beam_size,
102
+ int32_t* new_codes,
103
+ float* new_residuals = nullptr,
104
+ float* new_distances = nullptr) const;
105
+
106
+ void refine_beam_LUT(
107
+ size_t n,
108
+ const float* query_norms,
109
+ const float* query_cp,
110
+ int new_beam_size,
111
+ int32_t* new_codes,
112
+ float* new_distances = nullptr) const;
113
+
114
+ /** Beam search can consume a lot of memory. This function estimates the
115
+ * amount of mem used by refine_beam to adjust the batch size
116
+ *
117
+ * @param beam_size if != -1, override the beam size
118
+ */
119
+ size_t memory_per_point(int beam_size = -1) const;
120
+
121
+ /** Cross products used in codebook tables
122
+ *
123
+ * These are used to keep track of norms of centroids.
124
+ */
125
+ void compute_codebook_tables();
126
+
127
+ /// dot products of all codebook vectors with each other
128
+ /// size total_codebook_size * total_codebook_size
129
+ std::vector<float> codebook_cross_products;
130
+ /// norms of all vectors
131
+ std::vector<float> cent_norms;
132
+ };
133
+
134
+ /** Encode a residual by sampling from a centroid table.
135
+ *
136
+ * This is a single encoding step of the residual quantizer.
137
+ * It allows low-level access to the encoding function, exposed mainly for unit
138
+ * tests.
139
+ *
140
+ * @param n number of vectors to handle
141
+ * @param residuals vectors to encode, size (n, beam_size, d)
142
+ * @param cent centroids, size (K, d)
143
+ * @param beam_size input beam size
144
+ * @param m size of the codes for the previous encoding steps
145
+ * @param codes code array for the previous steps of the beam (n,
146
+ * beam_size, m)
147
+ * @param new_beam_size output beam size (should be <= K * beam_size)
148
+ * @param new_codes output codes, size (n, new_beam_size, m + 1)
149
+ * @param new_residuals output residuals, size (n, new_beam_size, d)
150
+ * @param new_distances output distances, size (n, new_beam_size)
151
+ * @param assign_index if non-NULL, will be used to perform assignment
152
+ */
153
+ void beam_search_encode_step(
154
+ size_t d,
155
+ size_t K,
156
+ const float* cent,
157
+ size_t n,
158
+ size_t beam_size,
159
+ const float* residuals,
160
+ size_t m,
161
+ const int32_t* codes,
162
+ size_t new_beam_size,
163
+ int32_t* new_codes,
164
+ float* new_residuals,
165
+ float* new_distances,
166
+ Index* assign_index = nullptr);
167
+
168
+ /** Encode a set of vectors using their dot products with the codebooks
169
+ *
170
+ */
171
+ void beam_search_encode_step_tab(
172
+ size_t K,
173
+ size_t n,
174
+ size_t beam_size, // input sizes
175
+ const float* codebook_cross_norms, // size K * ldc
176
+ size_t ldc, // >= K
177
+ const uint64_t* codebook_offsets, // m
178
+ const float* query_cp, // size n * ldqc
179
+ size_t ldqc, // >= K
180
+ const float* cent_norms_i, // size K
181
+ size_t m,
182
+ const int32_t* codes, // n * beam_size * m
183
+ const float* distances, // n * beam_size
184
+ size_t new_beam_size,
185
+ int32_t* new_codes, // n * new_beam_size * (m + 1)
186
+ float* new_distances); // n * new_beam_size
187
+
188
+ }; // namespace faiss
@@ -5,49 +5,38 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  /*
10
9
  * Structures that collect search results from distance computations
11
10
  */
12
11
 
13
12
  #pragma once
14
13
 
15
-
14
+ #include <faiss/impl/AuxIndexStructures.h>
16
15
  #include <faiss/utils/Heap.h>
17
16
  #include <faiss/utils/partitioning.h>
18
- #include <faiss/impl/AuxIndexStructures.h>
19
-
20
17
 
21
18
  namespace faiss {
22
19
 
23
-
24
-
25
20
  /*****************************************************************
26
21
  * Heap based result handler
27
22
  *****************************************************************/
28
23
 
29
-
30
- template<class C>
24
+ template <class C>
31
25
  struct HeapResultHandler {
32
-
33
26
  using T = typename C::T;
34
27
  using TI = typename C::TI;
35
28
 
36
29
  int nq;
37
- T *heap_dis_tab;
38
- TI *heap_ids_tab;
30
+ T* heap_dis_tab;
31
+ TI* heap_ids_tab;
39
32
 
40
- int64_t k; // number of results to keep
33
+ int64_t k; // number of results to keep
41
34
 
42
- HeapResultHandler(
43
- size_t nq,
44
- T * heap_dis_tab, TI * heap_ids_tab,
45
- size_t k):
46
- nq(nq),
47
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
48
- {
49
-
50
- }
35
+ HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k)
36
+ : nq(nq),
37
+ heap_dis_tab(heap_dis_tab),
38
+ heap_ids_tab(heap_ids_tab),
39
+ k(k) {}
51
40
 
52
41
  /******************************************************
53
42
  * API for 1 result at a time (each SingleResultHandler is
@@ -55,20 +44,20 @@ struct HeapResultHandler {
55
44
  */
56
45
 
57
46
  struct SingleResultHandler {
58
- HeapResultHandler & hr;
47
+ HeapResultHandler& hr;
59
48
  size_t k;
60
49
 
61
- T *heap_dis;
62
- TI *heap_ids;
50
+ T* heap_dis;
51
+ TI* heap_ids;
63
52
  T thresh;
64
53
 
65
- SingleResultHandler(HeapResultHandler &hr): hr(hr), k(hr.k) {}
54
+ SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {}
66
55
 
67
56
  /// begin results for query # i
68
57
  void begin(size_t i) {
69
58
  heap_dis = hr.heap_dis_tab + i * k;
70
59
  heap_ids = hr.heap_ids_tab + i * k;
71
- heap_heapify<C> (k, heap_dis, heap_ids);
60
+ heap_heapify<C>(k, heap_dis, heap_ids);
72
61
  thresh = heap_dis[0];
73
62
  }
74
63
 
@@ -82,11 +71,10 @@ struct HeapResultHandler {
82
71
 
83
72
  /// series of results for query i is done
84
73
  void end() {
85
- heap_reorder<C> (k, heap_dis, heap_ids);
74
+ heap_reorder<C>(k, heap_dis, heap_ids);
86
75
  }
87
76
  };
88
77
 
89
-
90
78
  /******************************************************
91
79
  * API for multiple results (called from 1 thread)
92
80
  */
@@ -97,20 +85,21 @@ struct HeapResultHandler {
97
85
  void begin_multiple(size_t i0, size_t i1) {
98
86
  this->i0 = i0;
99
87
  this->i1 = i1;
100
- for(size_t i = i0; i < i1; i++) {
101
- heap_heapify<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
88
+ for (size_t i = i0; i < i1; i++) {
89
+ heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
102
90
  }
103
91
  }
104
92
 
105
93
  /// add results for query i0..i1 and j0..j1
106
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
107
- // maybe parallel for
108
- for (size_t i = i0; i < i1; i++) {
109
- T * heap_dis = heap_dis_tab + i * k;
110
- TI * heap_ids = heap_ids_tab + i * k;
94
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
95
+ #pragma omp parallel for
96
+ for (int64_t i = i0; i < i1; i++) {
97
+ T* heap_dis = heap_dis_tab + i * k;
98
+ TI* heap_ids = heap_ids_tab + i * k;
99
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
111
100
  T thresh = heap_dis[0];
112
101
  for (size_t j = j0; j < j1; j++) {
113
- T dis = *dis_tab++;
102
+ T dis = dis_tab_i[j];
114
103
  if (C::cmp(thresh, dis)) {
115
104
  heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
116
105
  thresh = heap_dis[0];
@@ -122,11 +111,10 @@ struct HeapResultHandler {
122
111
  /// series of results for queries i0..i1 is done
123
112
  void end_multiple() {
124
113
  // maybe parallel for
125
- for(size_t i = i0; i < i1; i++) {
126
- heap_reorder<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
114
+ for (size_t i = i0; i < i1; i++) {
115
+ heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
127
116
  }
128
117
  }
129
-
130
118
  };
131
119
 
132
120
  /*****************************************************************
@@ -138,31 +126,25 @@ struct HeapResultHandler {
138
126
  * distance array.
139
127
  *****************************************************************/
140
128
 
141
-
142
-
143
129
  /// Reservoir for a single query
144
- template<class C>
130
+ template <class C>
145
131
  struct ReservoirTopN {
146
132
  using T = typename C::T;
147
133
  using TI = typename C::TI;
148
134
 
149
- T *vals;
150
- TI *ids;
135
+ T* vals;
136
+ TI* ids;
151
137
 
152
- size_t i; // number of stored elements
153
- size_t n; // number of requested elements
154
- size_t capacity; // size of storage
138
+ size_t i; // number of stored elements
139
+ size_t n; // number of requested elements
140
+ size_t capacity; // size of storage
155
141
 
156
142
  T threshold; // current threshold
157
143
 
158
144
  ReservoirTopN() {}
159
145
 
160
- ReservoirTopN(
161
- size_t n, size_t capacity,
162
- T *vals, TI *ids
163
- ):
164
- vals(vals), ids(ids),
165
- i(0), n(n), capacity(capacity) {
146
+ ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
147
+ : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
166
148
  assert(n < capacity);
167
149
  threshold = C::neutral();
168
150
  }
@@ -184,55 +166,47 @@ struct ReservoirTopN {
184
166
  assert(i == capacity);
185
167
 
186
168
  threshold = partition_fuzzy<C>(
187
- vals, ids, capacity, n, (capacity + n) / 2,
188
- &i);
169
+ vals, ids, capacity, n, (capacity + n) / 2, &i);
189
170
  }
190
171
 
191
- void to_result(T *heap_dis, TI *heap_ids) const {
192
-
172
+ void to_result(T* heap_dis, TI* heap_ids) const {
193
173
  for (int j = 0; j < std::min(i, n); j++) {
194
- heap_push<C>(
195
- j + 1, heap_dis, heap_ids,
196
- vals[j], ids[j]
197
- );
174
+ heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
198
175
  }
199
176
 
200
177
  if (i < n) {
201
- heap_reorder<C> (i, heap_dis, heap_ids);
178
+ heap_reorder<C>(i, heap_dis, heap_ids);
202
179
  // add empty results
203
- heap_heapify<C> (n - i, heap_dis + i, heap_ids + i);
180
+ heap_heapify<C>(n - i, heap_dis + i, heap_ids + i);
204
181
  } else {
205
182
  // add remaining elements
206
- heap_addn<C> (n, heap_dis, heap_ids, vals + n, ids + n, i - n);
207
- heap_reorder<C> (n, heap_dis, heap_ids);
183
+ heap_addn<C>(n, heap_dis, heap_ids, vals + n, ids + n, i - n);
184
+ heap_reorder<C>(n, heap_dis, heap_ids);
208
185
  }
209
-
210
186
  }
211
-
212
187
  };
213
188
 
214
-
215
-
216
- template<class C>
189
+ template <class C>
217
190
  struct ReservoirResultHandler {
218
-
219
191
  using T = typename C::T;
220
192
  using TI = typename C::TI;
221
193
 
222
194
  int nq;
223
- T *heap_dis_tab;
224
- TI *heap_ids_tab;
195
+ T* heap_dis_tab;
196
+ TI* heap_ids_tab;
225
197
 
226
- int64_t k; // number of results to keep
198
+ int64_t k; // number of results to keep
227
199
  size_t capacity; // capacity of the reservoirs
228
200
 
229
201
  ReservoirResultHandler(
230
- size_t nq,
231
- T * heap_dis_tab, TI * heap_ids_tab,
232
- size_t k):
233
- nq(nq),
234
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
235
- {
202
+ size_t nq,
203
+ T* heap_dis_tab,
204
+ TI* heap_ids_tab,
205
+ size_t k)
206
+ : nq(nq),
207
+ heap_dis_tab(heap_dis_tab),
208
+ heap_ids_tab(heap_ids_tab),
209
+ k(k) {
236
210
  // double then round up to multiple of 16 (for SIMD alignment)
237
211
  capacity = (2 * k + 15) & ~15;
238
212
  }
@@ -243,23 +217,26 @@ struct ReservoirResultHandler {
243
217
  */
244
218
 
245
219
  struct SingleResultHandler {
246
- ReservoirResultHandler & hr;
220
+ ReservoirResultHandler& hr;
247
221
 
248
222
  std::vector<T> reservoir_dis;
249
223
  std::vector<TI> reservoir_ids;
250
224
  ReservoirTopN<C> res1;
251
225
 
252
- SingleResultHandler(ReservoirResultHandler &hr):
253
- hr(hr), reservoir_dis(hr.capacity), reservoir_ids(hr.capacity)
254
- {
255
- }
226
+ SingleResultHandler(ReservoirResultHandler& hr)
227
+ : hr(hr),
228
+ reservoir_dis(hr.capacity),
229
+ reservoir_ids(hr.capacity) {}
256
230
 
257
231
  size_t i;
258
232
 
259
233
  /// begin results for query # i
260
234
  void begin(size_t i) {
261
235
  res1 = ReservoirTopN<C>(
262
- hr.k, hr.capacity, reservoir_dis.data(), reservoir_ids.data());
236
+ hr.k,
237
+ hr.capacity,
238
+ reservoir_dis.data(),
239
+ reservoir_ids.data());
263
240
  this->i = i;
264
241
  }
265
242
 
@@ -270,8 +247,8 @@ struct ReservoirResultHandler {
270
247
 
271
248
  /// series of results for query i is done
272
249
  void end() {
273
- T * heap_dis = hr.heap_dis_tab + i * hr.k;
274
- TI * heap_ids = hr.heap_ids_tab + i * hr.k;
250
+ T* heap_dis = hr.heap_dis_tab + i * hr.k;
251
+ TI* heap_ids = hr.heap_ids_tab + i * hr.k;
275
252
  res1.to_result(heap_dis, heap_ids);
276
253
  }
277
254
  };
@@ -295,20 +272,22 @@ struct ReservoirResultHandler {
295
272
  reservoirs.clear();
296
273
  for (size_t i = i0; i < i1; i++) {
297
274
  reservoirs.emplace_back(
298
- k, capacity,
299
- reservoir_dis.data() + (i - i0) * capacity,
300
- reservoir_ids.data() + (i - i0) * capacity
301
- );
275
+ k,
276
+ capacity,
277
+ reservoir_dis.data() + (i - i0) * capacity,
278
+ reservoir_ids.data() + (i - i0) * capacity);
302
279
  }
303
280
  }
304
281
 
305
282
  /// add results for query i0..i1 and j0..j1
306
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
283
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
307
284
  // maybe parallel for
308
- for (size_t i = i0; i < i1; i++) {
309
- ReservoirTopN<C> & reservoir = reservoirs[i - i0];
285
+ #pragma omp parallel for
286
+ for (int64_t i = i0; i < i1; i++) {
287
+ ReservoirTopN<C>& reservoir = reservoirs[i - i0];
288
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
310
289
  for (size_t j = j0; j < j1; j++) {
311
- T dis = *dis_tab++;
290
+ T dis = dis_tab_i[j];
312
291
  reservoir.add(dis, j);
313
292
  }
314
293
  }
@@ -317,32 +296,27 @@ struct ReservoirResultHandler {
317
296
  /// series of results for queries i0..i1 is done
318
297
  void end_multiple() {
319
298
  // maybe parallel for
320
- for(size_t i = i0; i < i1; i++) {
299
+ for (size_t i = i0; i < i1; i++) {
321
300
  reservoirs[i - i0].to_result(
322
- heap_dis_tab + i * k, heap_ids_tab + i * k);
301
+ heap_dis_tab + i * k, heap_ids_tab + i * k);
323
302
  }
324
303
  }
325
-
326
304
  };
327
305
 
328
-
329
306
  /*****************************************************************
330
307
  * Result handler for range searches
331
308
  *****************************************************************/
332
309
 
333
-
334
-
335
- template<class C>
310
+ template <class C>
336
311
  struct RangeSearchResultHandler {
337
312
  using T = typename C::T;
338
313
  using TI = typename C::TI;
339
314
 
340
- RangeSearchResult *res;
315
+ RangeSearchResult* res;
341
316
  float radius;
342
317
 
343
- RangeSearchResultHandler(RangeSearchResult *res, float radius):
344
- res(res), radius(radius)
345
- {}
318
+ RangeSearchResultHandler(RangeSearchResult* res, float radius)
319
+ : res(res), radius(radius) {}
346
320
 
347
321
  /******************************************************
348
322
  * API for 1 result at a time (each SingleResultHandler is
@@ -353,11 +327,10 @@ struct RangeSearchResultHandler {
353
327
  // almost the same interface as RangeSearchResultHandler
354
328
  RangeSearchPartialResult pres;
355
329
  float radius;
356
- RangeQueryResult *qr = nullptr;
330
+ RangeQueryResult* qr = nullptr;
357
331
 
358
- SingleResultHandler(RangeSearchResultHandler &rh):
359
- pres(rh.res), radius(rh.radius)
360
- {}
332
+ SingleResultHandler(RangeSearchResultHandler& rh)
333
+ : pres(rh.res), radius(rh.radius) {}
361
334
 
362
335
  /// begin results for query # i
363
336
  void begin(size_t i) {
@@ -366,15 +339,13 @@ struct RangeSearchResultHandler {
366
339
 
367
340
  /// add one result for query i
368
341
  void add_result(T dis, TI idx) {
369
-
370
342
  if (C::cmp(radius, dis)) {
371
343
  qr->add(dis, idx);
372
344
  }
373
345
  }
374
346
 
375
347
  /// series of results for query i is done
376
- void end() {
377
- }
348
+ void end() {}
378
349
 
379
350
  ~SingleResultHandler() {
380
351
  pres.finalize();
@@ -387,8 +358,8 @@ struct RangeSearchResultHandler {
387
358
 
388
359
  size_t i0, i1;
389
360
 
390
- std::vector <RangeSearchPartialResult *> partial_results;
391
- std::vector <size_t> j0s;
361
+ std::vector<RangeSearchPartialResult*> partial_results;
362
+ std::vector<size_t> j0s;
392
363
  int pr = 0;
393
364
 
394
365
  /// begin
@@ -399,8 +370,8 @@ struct RangeSearchResultHandler {
399
370
 
400
371
  /// add results for query i0..i1 and j0..j1
401
372
 
402
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
403
- RangeSearchPartialResult *pres;
373
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
374
+ RangeSearchPartialResult* pres;
404
375
  // there is one RangeSearchPartialResult structure per j0
405
376
  // (= block of columns of the large distance matrix)
406
377
  // it is a bit tricky to find the proper PartialResult structure
@@ -414,39 +385,32 @@ struct RangeSearchResultHandler {
414
385
  pres = partial_results[pr];
415
386
  pr++;
416
387
  } else { // did not find this j0
417
- pres = new RangeSearchPartialResult (res);
388
+ pres = new RangeSearchPartialResult(res);
418
389
  partial_results.push_back(pres);
419
390
  j0s.push_back(j0);
420
391
  pr = partial_results.size();
421
392
  }
422
393
 
423
394
  for (size_t i = i0; i < i1; i++) {
424
- const float *ip_line = dis_tab + (i - i0) * (j1 - j0);
425
- RangeQueryResult & qres = pres->new_result (i);
395
+ const float* ip_line = dis_tab + (i - i0) * (j1 - j0);
396
+ RangeQueryResult& qres = pres->new_result(i);
426
397
 
427
398
  for (size_t j = j0; j < j1; j++) {
428
399
  float dis = *ip_line++;
429
400
  if (C::cmp(radius, dis)) {
430
- qres.add (dis, j);
401
+ qres.add(dis, j);
431
402
  }
432
403
  }
433
404
  }
434
405
  }
435
406
 
436
- void end_multiple() {
437
-
438
- }
407
+ void end_multiple() {}
439
408
 
440
409
  ~RangeSearchResultHandler() {
441
410
  if (partial_results.size() > 0) {
442
- RangeSearchPartialResult::merge (partial_results);
411
+ RangeSearchPartialResult::merge(partial_results);
443
412
  }
444
413
  }
445
-
446
414
  };
447
415
 
448
-
449
-
450
-
451
- } // namespace faiss
452
-
416
+ } // namespace faiss