faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,130 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstdint>
11
+ #include <vector>
12
+
13
+ #include <faiss/Clustering.h>
14
+ #include <faiss/impl/AdditiveQuantizer.h>
15
+
16
+ namespace faiss {
17
+
18
+ /** Residual quantizer with variable number of bits per sub-quantizer
19
+ *
20
+ * The residual centroids are stored in a big cumulative centroid table.
21
+ * The codes are represented either as a non-compact table of size (n, M) or
22
+ * as the compact output (n, code_size).
23
+ */
24
+
25
+ struct ResidualQuantizer : AdditiveQuantizer {
26
+ /// initialization
27
+ enum train_type_t {
28
+ Train_default, ///< regular k-means
29
+ Train_progressive_dim, ///< progressive dim clustering
30
+ };
31
+
32
+ // set this bit on train_type if beam is to be trained only on the
33
+ // first element of the beam (faster but less accurate)
34
+ static const int Train_top_beam = 1024;
35
+ train_type_t train_type;
36
+
37
+ /// beam size used for training and for encoding
38
+ int max_beam_size;
39
+
40
+ /// distance matrices with beam search can get large, so use this
41
+ /// to batch computations at encoding time.
42
+ size_t max_mem_distances;
43
+
44
+ /// clustering parameters
45
+ ProgressiveDimClusteringParameters cp;
46
+
47
+ /// if non-NULL, use this index for assignment
48
+ ProgressiveDimIndexFactory* assign_index_factory;
49
+
50
+ ResidualQuantizer(size_t d, const std::vector<size_t>& nbits);
51
+
52
+ ResidualQuantizer(
53
+ size_t d, /* dimensionality of the input vectors */
54
+ size_t M, /* number of subquantizers */
55
+ size_t nbits); /* number of bit per subvector index */
56
+
57
+ ResidualQuantizer();
58
+
59
+ // Train the residual quantizer
60
+ void train(size_t n, const float* x) override;
61
+
62
+ /** Encode a set of vectors
63
+ *
64
+ * @param x vectors to encode, size n * d
65
+ * @param codes output codes, size n * code_size
66
+ */
67
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
68
+
69
+ /** lower-level encode function
70
+ *
71
+ * @param n number of vectors to handle
72
+ * @param residuals vectors to encode, size (n, beam_size, d)
73
+ * @param beam_size input beam size
74
+ * @param new_beam_size output beam size (should be <= K * beam_size)
75
+ * @param new_codes output codes, size (n, new_beam_size, m + 1)
76
+ * @param new_residuals output residuals, size (n, new_beam_size, d)
77
+ * @param new_distances output distances, size (n, new_beam_size)
78
+ */
79
+ void refine_beam(
80
+ size_t n,
81
+ size_t beam_size,
82
+ const float* residuals,
83
+ int new_beam_size,
84
+ int32_t* new_codes,
85
+ float* new_residuals = nullptr,
86
+ float* new_distances = nullptr) const;
87
+
88
+ /** Beam search can consume a lot of memory. This function estimates the
89
+ * amount of mem used by refine_beam to adjust the batch size
90
+ *
91
+ * @param beam_size if != -1, override the beam size
92
+ */
93
+ size_t memory_per_point(int beam_size = -1) const;
94
+ };
95
+
96
+ /** Encode a residual by sampling from a centroid table.
97
+ *
98
+ * This is a single encoding step of the residual quantizer.
99
+ * It allows low-level access to the encoding function, exposed mainly for unit
100
+ * tests.
101
+ *
102
+ * @param n number of vectors to handle
103
+ * @param residuals vectors to encode, size (n, beam_size, d)
104
+ * @param cent centroids, size (K, d)
105
+ * @param beam_size input beam size
106
+ * @param m size of the codes for the previous encoding steps
107
+ * @param codes code array for the previous steps of the beam (n,
108
+ * beam_size, m)
109
+ * @param new_beam_size output beam size (should be <= K * beam_size)
110
+ * @param new_codes output codes, size (n, new_beam_size, m + 1)
111
+ * @param new_residuals output residuals, size (n, new_beam_size, d)
112
+ * @param new_distances output distances, size (n, new_beam_size)
113
+ * @param assign_index if non-NULL, will be used to perform assignment
114
+ */
115
+ void beam_search_encode_step(
116
+ size_t d,
117
+ size_t K,
118
+ const float* cent,
119
+ size_t n,
120
+ size_t beam_size,
121
+ const float* residuals,
122
+ size_t m,
123
+ const int32_t* codes,
124
+ size_t new_beam_size,
125
+ int32_t* new_codes,
126
+ float* new_residuals,
127
+ float* new_distances,
128
+ Index* assign_index = nullptr);
129
+
130
+ }; // namespace faiss
@@ -5,49 +5,38 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  /*
10
9
  * Structures that collect search results from distance computations
11
10
  */
12
11
 
13
12
  #pragma once
14
13
 
15
-
14
+ #include <faiss/impl/AuxIndexStructures.h>
16
15
  #include <faiss/utils/Heap.h>
17
16
  #include <faiss/utils/partitioning.h>
18
- #include <faiss/impl/AuxIndexStructures.h>
19
-
20
17
 
21
18
  namespace faiss {
22
19
 
23
-
24
-
25
20
  /*****************************************************************
26
21
  * Heap based result handler
27
22
  *****************************************************************/
28
23
 
29
-
30
- template<class C>
24
+ template <class C>
31
25
  struct HeapResultHandler {
32
-
33
26
  using T = typename C::T;
34
27
  using TI = typename C::TI;
35
28
 
36
29
  int nq;
37
- T *heap_dis_tab;
38
- TI *heap_ids_tab;
30
+ T* heap_dis_tab;
31
+ TI* heap_ids_tab;
39
32
 
40
- int64_t k; // number of results to keep
33
+ int64_t k; // number of results to keep
41
34
 
42
- HeapResultHandler(
43
- size_t nq,
44
- T * heap_dis_tab, TI * heap_ids_tab,
45
- size_t k):
46
- nq(nq),
47
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
48
- {
49
-
50
- }
35
+ HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k)
36
+ : nq(nq),
37
+ heap_dis_tab(heap_dis_tab),
38
+ heap_ids_tab(heap_ids_tab),
39
+ k(k) {}
51
40
 
52
41
  /******************************************************
53
42
  * API for 1 result at a time (each SingleResultHandler is
@@ -55,20 +44,20 @@ struct HeapResultHandler {
55
44
  */
56
45
 
57
46
  struct SingleResultHandler {
58
- HeapResultHandler & hr;
47
+ HeapResultHandler& hr;
59
48
  size_t k;
60
49
 
61
- T *heap_dis;
62
- TI *heap_ids;
50
+ T* heap_dis;
51
+ TI* heap_ids;
63
52
  T thresh;
64
53
 
65
- SingleResultHandler(HeapResultHandler &hr): hr(hr), k(hr.k) {}
54
+ SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {}
66
55
 
67
56
  /// begin results for query # i
68
57
  void begin(size_t i) {
69
58
  heap_dis = hr.heap_dis_tab + i * k;
70
59
  heap_ids = hr.heap_ids_tab + i * k;
71
- heap_heapify<C> (k, heap_dis, heap_ids);
60
+ heap_heapify<C>(k, heap_dis, heap_ids);
72
61
  thresh = heap_dis[0];
73
62
  }
74
63
 
@@ -82,11 +71,10 @@ struct HeapResultHandler {
82
71
 
83
72
  /// series of results for query i is done
84
73
  void end() {
85
- heap_reorder<C> (k, heap_dis, heap_ids);
74
+ heap_reorder<C>(k, heap_dis, heap_ids);
86
75
  }
87
76
  };
88
77
 
89
-
90
78
  /******************************************************
91
79
  * API for multiple results (called from 1 thread)
92
80
  */
@@ -97,20 +85,21 @@ struct HeapResultHandler {
97
85
  void begin_multiple(size_t i0, size_t i1) {
98
86
  this->i0 = i0;
99
87
  this->i1 = i1;
100
- for(size_t i = i0; i < i1; i++) {
101
- heap_heapify<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
88
+ for (size_t i = i0; i < i1; i++) {
89
+ heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
102
90
  }
103
91
  }
104
92
 
105
93
  /// add results for query i0..i1 and j0..j1
106
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
107
- // maybe parallel for
108
- for (size_t i = i0; i < i1; i++) {
109
- T * heap_dis = heap_dis_tab + i * k;
110
- TI * heap_ids = heap_ids_tab + i * k;
94
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
95
+ #pragma omp parallel for
96
+ for (int64_t i = i0; i < i1; i++) {
97
+ T* heap_dis = heap_dis_tab + i * k;
98
+ TI* heap_ids = heap_ids_tab + i * k;
99
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
111
100
  T thresh = heap_dis[0];
112
101
  for (size_t j = j0; j < j1; j++) {
113
- T dis = *dis_tab++;
102
+ T dis = dis_tab_i[j];
114
103
  if (C::cmp(thresh, dis)) {
115
104
  heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
116
105
  thresh = heap_dis[0];
@@ -122,11 +111,10 @@ struct HeapResultHandler {
122
111
  /// series of results for queries i0..i1 is done
123
112
  void end_multiple() {
124
113
  // maybe parallel for
125
- for(size_t i = i0; i < i1; i++) {
126
- heap_reorder<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
114
+ for (size_t i = i0; i < i1; i++) {
115
+ heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
127
116
  }
128
117
  }
129
-
130
118
  };
131
119
 
132
120
  /*****************************************************************
@@ -138,31 +126,25 @@ struct HeapResultHandler {
138
126
  * distance array.
139
127
  *****************************************************************/
140
128
 
141
-
142
-
143
129
  /// Reservoir for a single query
144
- template<class C>
130
+ template <class C>
145
131
  struct ReservoirTopN {
146
132
  using T = typename C::T;
147
133
  using TI = typename C::TI;
148
134
 
149
- T *vals;
150
- TI *ids;
135
+ T* vals;
136
+ TI* ids;
151
137
 
152
- size_t i; // number of stored elements
153
- size_t n; // number of requested elements
154
- size_t capacity; // size of storage
138
+ size_t i; // number of stored elements
139
+ size_t n; // number of requested elements
140
+ size_t capacity; // size of storage
155
141
 
156
142
  T threshold; // current threshold
157
143
 
158
144
  ReservoirTopN() {}
159
145
 
160
- ReservoirTopN(
161
- size_t n, size_t capacity,
162
- T *vals, TI *ids
163
- ):
164
- vals(vals), ids(ids),
165
- i(0), n(n), capacity(capacity) {
146
+ ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
147
+ : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
166
148
  assert(n < capacity);
167
149
  threshold = C::neutral();
168
150
  }
@@ -184,55 +166,47 @@ struct ReservoirTopN {
184
166
  assert(i == capacity);
185
167
 
186
168
  threshold = partition_fuzzy<C>(
187
- vals, ids, capacity, n, (capacity + n) / 2,
188
- &i);
169
+ vals, ids, capacity, n, (capacity + n) / 2, &i);
189
170
  }
190
171
 
191
- void to_result(T *heap_dis, TI *heap_ids) const {
192
-
172
+ void to_result(T* heap_dis, TI* heap_ids) const {
193
173
  for (int j = 0; j < std::min(i, n); j++) {
194
- heap_push<C>(
195
- j + 1, heap_dis, heap_ids,
196
- vals[j], ids[j]
197
- );
174
+ heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
198
175
  }
199
176
 
200
177
  if (i < n) {
201
- heap_reorder<C> (i, heap_dis, heap_ids);
178
+ heap_reorder<C>(i, heap_dis, heap_ids);
202
179
  // add empty results
203
- heap_heapify<C> (n - i, heap_dis + i, heap_ids + i);
180
+ heap_heapify<C>(n - i, heap_dis + i, heap_ids + i);
204
181
  } else {
205
182
  // add remaining elements
206
- heap_addn<C> (n, heap_dis, heap_ids, vals + n, ids + n, i - n);
207
- heap_reorder<C> (n, heap_dis, heap_ids);
183
+ heap_addn<C>(n, heap_dis, heap_ids, vals + n, ids + n, i - n);
184
+ heap_reorder<C>(n, heap_dis, heap_ids);
208
185
  }
209
-
210
186
  }
211
-
212
187
  };
213
188
 
214
-
215
-
216
- template<class C>
189
+ template <class C>
217
190
  struct ReservoirResultHandler {
218
-
219
191
  using T = typename C::T;
220
192
  using TI = typename C::TI;
221
193
 
222
194
  int nq;
223
- T *heap_dis_tab;
224
- TI *heap_ids_tab;
195
+ T* heap_dis_tab;
196
+ TI* heap_ids_tab;
225
197
 
226
- int64_t k; // number of results to keep
198
+ int64_t k; // number of results to keep
227
199
  size_t capacity; // capacity of the reservoirs
228
200
 
229
201
  ReservoirResultHandler(
230
- size_t nq,
231
- T * heap_dis_tab, TI * heap_ids_tab,
232
- size_t k):
233
- nq(nq),
234
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
235
- {
202
+ size_t nq,
203
+ T* heap_dis_tab,
204
+ TI* heap_ids_tab,
205
+ size_t k)
206
+ : nq(nq),
207
+ heap_dis_tab(heap_dis_tab),
208
+ heap_ids_tab(heap_ids_tab),
209
+ k(k) {
236
210
  // double then round up to multiple of 16 (for SIMD alignment)
237
211
  capacity = (2 * k + 15) & ~15;
238
212
  }
@@ -243,23 +217,26 @@ struct ReservoirResultHandler {
243
217
  */
244
218
 
245
219
  struct SingleResultHandler {
246
- ReservoirResultHandler & hr;
220
+ ReservoirResultHandler& hr;
247
221
 
248
222
  std::vector<T> reservoir_dis;
249
223
  std::vector<TI> reservoir_ids;
250
224
  ReservoirTopN<C> res1;
251
225
 
252
- SingleResultHandler(ReservoirResultHandler &hr):
253
- hr(hr), reservoir_dis(hr.capacity), reservoir_ids(hr.capacity)
254
- {
255
- }
226
+ SingleResultHandler(ReservoirResultHandler& hr)
227
+ : hr(hr),
228
+ reservoir_dis(hr.capacity),
229
+ reservoir_ids(hr.capacity) {}
256
230
 
257
231
  size_t i;
258
232
 
259
233
  /// begin results for query # i
260
234
  void begin(size_t i) {
261
235
  res1 = ReservoirTopN<C>(
262
- hr.k, hr.capacity, reservoir_dis.data(), reservoir_ids.data());
236
+ hr.k,
237
+ hr.capacity,
238
+ reservoir_dis.data(),
239
+ reservoir_ids.data());
263
240
  this->i = i;
264
241
  }
265
242
 
@@ -270,8 +247,8 @@ struct ReservoirResultHandler {
270
247
 
271
248
  /// series of results for query i is done
272
249
  void end() {
273
- T * heap_dis = hr.heap_dis_tab + i * hr.k;
274
- TI * heap_ids = hr.heap_ids_tab + i * hr.k;
250
+ T* heap_dis = hr.heap_dis_tab + i * hr.k;
251
+ TI* heap_ids = hr.heap_ids_tab + i * hr.k;
275
252
  res1.to_result(heap_dis, heap_ids);
276
253
  }
277
254
  };
@@ -295,20 +272,22 @@ struct ReservoirResultHandler {
295
272
  reservoirs.clear();
296
273
  for (size_t i = i0; i < i1; i++) {
297
274
  reservoirs.emplace_back(
298
- k, capacity,
299
- reservoir_dis.data() + (i - i0) * capacity,
300
- reservoir_ids.data() + (i - i0) * capacity
301
- );
275
+ k,
276
+ capacity,
277
+ reservoir_dis.data() + (i - i0) * capacity,
278
+ reservoir_ids.data() + (i - i0) * capacity);
302
279
  }
303
280
  }
304
281
 
305
282
  /// add results for query i0..i1 and j0..j1
306
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
283
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
307
284
  // maybe parallel for
308
- for (size_t i = i0; i < i1; i++) {
309
- ReservoirTopN<C> & reservoir = reservoirs[i - i0];
285
+ #pragma omp parallel for
286
+ for (int64_t i = i0; i < i1; i++) {
287
+ ReservoirTopN<C>& reservoir = reservoirs[i - i0];
288
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
310
289
  for (size_t j = j0; j < j1; j++) {
311
- T dis = *dis_tab++;
290
+ T dis = dis_tab_i[j];
312
291
  reservoir.add(dis, j);
313
292
  }
314
293
  }
@@ -317,32 +296,27 @@ struct ReservoirResultHandler {
317
296
  /// series of results for queries i0..i1 is done
318
297
  void end_multiple() {
319
298
  // maybe parallel for
320
- for(size_t i = i0; i < i1; i++) {
299
+ for (size_t i = i0; i < i1; i++) {
321
300
  reservoirs[i - i0].to_result(
322
- heap_dis_tab + i * k, heap_ids_tab + i * k);
301
+ heap_dis_tab + i * k, heap_ids_tab + i * k);
323
302
  }
324
303
  }
325
-
326
304
  };
327
305
 
328
-
329
306
  /*****************************************************************
330
307
  * Result handler for range searches
331
308
  *****************************************************************/
332
309
 
333
-
334
-
335
- template<class C>
310
+ template <class C>
336
311
  struct RangeSearchResultHandler {
337
312
  using T = typename C::T;
338
313
  using TI = typename C::TI;
339
314
 
340
- RangeSearchResult *res;
315
+ RangeSearchResult* res;
341
316
  float radius;
342
317
 
343
- RangeSearchResultHandler(RangeSearchResult *res, float radius):
344
- res(res), radius(radius)
345
- {}
318
+ RangeSearchResultHandler(RangeSearchResult* res, float radius)
319
+ : res(res), radius(radius) {}
346
320
 
347
321
  /******************************************************
348
322
  * API for 1 result at a time (each SingleResultHandler is
@@ -353,11 +327,10 @@ struct RangeSearchResultHandler {
353
327
  // almost the same interface as RangeSearchResultHandler
354
328
  RangeSearchPartialResult pres;
355
329
  float radius;
356
- RangeQueryResult *qr = nullptr;
330
+ RangeQueryResult* qr = nullptr;
357
331
 
358
- SingleResultHandler(RangeSearchResultHandler &rh):
359
- pres(rh.res), radius(rh.radius)
360
- {}
332
+ SingleResultHandler(RangeSearchResultHandler& rh)
333
+ : pres(rh.res), radius(rh.radius) {}
361
334
 
362
335
  /// begin results for query # i
363
336
  void begin(size_t i) {
@@ -366,15 +339,13 @@ struct RangeSearchResultHandler {
366
339
 
367
340
  /// add one result for query i
368
341
  void add_result(T dis, TI idx) {
369
-
370
342
  if (C::cmp(radius, dis)) {
371
343
  qr->add(dis, idx);
372
344
  }
373
345
  }
374
346
 
375
347
  /// series of results for query i is done
376
- void end() {
377
- }
348
+ void end() {}
378
349
 
379
350
  ~SingleResultHandler() {
380
351
  pres.finalize();
@@ -387,8 +358,8 @@ struct RangeSearchResultHandler {
387
358
 
388
359
  size_t i0, i1;
389
360
 
390
- std::vector <RangeSearchPartialResult *> partial_results;
391
- std::vector <size_t> j0s;
361
+ std::vector<RangeSearchPartialResult*> partial_results;
362
+ std::vector<size_t> j0s;
392
363
  int pr = 0;
393
364
 
394
365
  /// begin
@@ -399,8 +370,8 @@ struct RangeSearchResultHandler {
399
370
 
400
371
  /// add results for query i0..i1 and j0..j1
401
372
 
402
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
403
- RangeSearchPartialResult *pres;
373
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
374
+ RangeSearchPartialResult* pres;
404
375
  // there is one RangeSearchPartialResult structure per j0
405
376
  // (= block of columns of the large distance matrix)
406
377
  // it is a bit tricky to find the poper PartialResult structure
@@ -414,39 +385,32 @@ struct RangeSearchResultHandler {
414
385
  pres = partial_results[pr];
415
386
  pr++;
416
387
  } else { // did not find this j0
417
- pres = new RangeSearchPartialResult (res);
388
+ pres = new RangeSearchPartialResult(res);
418
389
  partial_results.push_back(pres);
419
390
  j0s.push_back(j0);
420
391
  pr = partial_results.size();
421
392
  }
422
393
 
423
394
  for (size_t i = i0; i < i1; i++) {
424
- const float *ip_line = dis_tab + (i - i0) * (j1 - j0);
425
- RangeQueryResult & qres = pres->new_result (i);
395
+ const float* ip_line = dis_tab + (i - i0) * (j1 - j0);
396
+ RangeQueryResult& qres = pres->new_result(i);
426
397
 
427
398
  for (size_t j = j0; j < j1; j++) {
428
399
  float dis = *ip_line++;
429
400
  if (C::cmp(radius, dis)) {
430
- qres.add (dis, j);
401
+ qres.add(dis, j);
431
402
  }
432
403
  }
433
404
  }
434
405
  }
435
406
 
436
- void end_multiple() {
437
-
438
- }
407
+ void end_multiple() {}
439
408
 
440
409
  ~RangeSearchResultHandler() {
441
410
  if (partial_results.size() > 0) {
442
- RangeSearchPartialResult::merge (partial_results);
411
+ RangeSearchPartialResult::merge(partial_results);
443
412
  }
444
413
  }
445
-
446
414
  };
447
415
 
448
-
449
-
450
-
451
- } // namespace faiss
452
-
416
+ } // namespace faiss