faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -5,49 +5,38 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  /*
10
9
  * Structures that collect search results from distance computations
11
10
  */
12
11
 
13
12
  #pragma once
14
13
 
15
-
14
+ #include <faiss/impl/AuxIndexStructures.h>
16
15
  #include <faiss/utils/Heap.h>
17
16
  #include <faiss/utils/partitioning.h>
18
- #include <faiss/impl/AuxIndexStructures.h>
19
-
20
17
 
21
18
  namespace faiss {
22
19
 
23
-
24
-
25
20
  /*****************************************************************
26
21
  * Heap based result handler
27
22
  *****************************************************************/
28
23
 
29
-
30
- template<class C>
24
+ template <class C>
31
25
  struct HeapResultHandler {
32
-
33
26
  using T = typename C::T;
34
27
  using TI = typename C::TI;
35
28
 
36
29
  int nq;
37
- T *heap_dis_tab;
38
- TI *heap_ids_tab;
30
+ T* heap_dis_tab;
31
+ TI* heap_ids_tab;
39
32
 
40
- int64_t k; // number of results to keep
33
+ int64_t k; // number of results to keep
41
34
 
42
- HeapResultHandler(
43
- size_t nq,
44
- T * heap_dis_tab, TI * heap_ids_tab,
45
- size_t k):
46
- nq(nq),
47
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
48
- {
49
-
50
- }
35
+ HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k)
36
+ : nq(nq),
37
+ heap_dis_tab(heap_dis_tab),
38
+ heap_ids_tab(heap_ids_tab),
39
+ k(k) {}
51
40
 
52
41
  /******************************************************
53
42
  * API for 1 result at a time (each SingleResultHandler is
@@ -55,20 +44,20 @@ struct HeapResultHandler {
55
44
  */
56
45
 
57
46
  struct SingleResultHandler {
58
- HeapResultHandler & hr;
47
+ HeapResultHandler& hr;
59
48
  size_t k;
60
49
 
61
- T *heap_dis;
62
- TI *heap_ids;
50
+ T* heap_dis;
51
+ TI* heap_ids;
63
52
  T thresh;
64
53
 
65
- SingleResultHandler(HeapResultHandler &hr): hr(hr), k(hr.k) {}
54
+ SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {}
66
55
 
67
56
  /// begin results for query # i
68
57
  void begin(size_t i) {
69
58
  heap_dis = hr.heap_dis_tab + i * k;
70
59
  heap_ids = hr.heap_ids_tab + i * k;
71
- heap_heapify<C> (k, heap_dis, heap_ids);
60
+ heap_heapify<C>(k, heap_dis, heap_ids);
72
61
  thresh = heap_dis[0];
73
62
  }
74
63
 
@@ -82,11 +71,10 @@ struct HeapResultHandler {
82
71
 
83
72
  /// series of results for query i is done
84
73
  void end() {
85
- heap_reorder<C> (k, heap_dis, heap_ids);
74
+ heap_reorder<C>(k, heap_dis, heap_ids);
86
75
  }
87
76
  };
88
77
 
89
-
90
78
  /******************************************************
91
79
  * API for multiple results (called from 1 thread)
92
80
  */
@@ -97,20 +85,21 @@ struct HeapResultHandler {
97
85
  void begin_multiple(size_t i0, size_t i1) {
98
86
  this->i0 = i0;
99
87
  this->i1 = i1;
100
- for(size_t i = i0; i < i1; i++) {
101
- heap_heapify<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
88
+ for (size_t i = i0; i < i1; i++) {
89
+ heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
102
90
  }
103
91
  }
104
92
 
105
93
  /// add results for query i0..i1 and j0..j1
106
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
107
- // maybe parallel for
108
- for (size_t i = i0; i < i1; i++) {
109
- T * heap_dis = heap_dis_tab + i * k;
110
- TI * heap_ids = heap_ids_tab + i * k;
94
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
95
+ #pragma omp parallel for
96
+ for (int64_t i = i0; i < i1; i++) {
97
+ T* heap_dis = heap_dis_tab + i * k;
98
+ TI* heap_ids = heap_ids_tab + i * k;
99
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
111
100
  T thresh = heap_dis[0];
112
101
  for (size_t j = j0; j < j1; j++) {
113
- T dis = *dis_tab++;
102
+ T dis = dis_tab_i[j];
114
103
  if (C::cmp(thresh, dis)) {
115
104
  heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
116
105
  thresh = heap_dis[0];
@@ -122,11 +111,10 @@ struct HeapResultHandler {
122
111
  /// series of results for queries i0..i1 is done
123
112
  void end_multiple() {
124
113
  // maybe parallel for
125
- for(size_t i = i0; i < i1; i++) {
126
- heap_reorder<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
114
+ for (size_t i = i0; i < i1; i++) {
115
+ heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
127
116
  }
128
117
  }
129
-
130
118
  };
131
119
 
132
120
  /*****************************************************************
@@ -138,31 +126,25 @@ struct HeapResultHandler {
138
126
  * distance array.
139
127
  *****************************************************************/
140
128
 
141
-
142
-
143
129
  /// Reservoir for a single query
144
- template<class C>
130
+ template <class C>
145
131
  struct ReservoirTopN {
146
132
  using T = typename C::T;
147
133
  using TI = typename C::TI;
148
134
 
149
- T *vals;
150
- TI *ids;
135
+ T* vals;
136
+ TI* ids;
151
137
 
152
- size_t i; // number of stored elements
153
- size_t n; // number of requested elements
154
- size_t capacity; // size of storage
138
+ size_t i; // number of stored elements
139
+ size_t n; // number of requested elements
140
+ size_t capacity; // size of storage
155
141
 
156
142
  T threshold; // current threshold
157
143
 
158
144
  ReservoirTopN() {}
159
145
 
160
- ReservoirTopN(
161
- size_t n, size_t capacity,
162
- T *vals, TI *ids
163
- ):
164
- vals(vals), ids(ids),
165
- i(0), n(n), capacity(capacity) {
146
+ ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
147
+ : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
166
148
  assert(n < capacity);
167
149
  threshold = C::neutral();
168
150
  }
@@ -184,55 +166,47 @@ struct ReservoirTopN {
184
166
  assert(i == capacity);
185
167
 
186
168
  threshold = partition_fuzzy<C>(
187
- vals, ids, capacity, n, (capacity + n) / 2,
188
- &i);
169
+ vals, ids, capacity, n, (capacity + n) / 2, &i);
189
170
  }
190
171
 
191
- void to_result(T *heap_dis, TI *heap_ids) const {
192
-
172
+ void to_result(T* heap_dis, TI* heap_ids) const {
193
173
  for (int j = 0; j < std::min(i, n); j++) {
194
- heap_push<C>(
195
- j + 1, heap_dis, heap_ids,
196
- vals[j], ids[j]
197
- );
174
+ heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
198
175
  }
199
176
 
200
177
  if (i < n) {
201
- heap_reorder<C> (i, heap_dis, heap_ids);
178
+ heap_reorder<C>(i, heap_dis, heap_ids);
202
179
  // add empty results
203
- heap_heapify<C> (n - i, heap_dis + i, heap_ids + i);
180
+ heap_heapify<C>(n - i, heap_dis + i, heap_ids + i);
204
181
  } else {
205
182
  // add remaining elements
206
- heap_addn<C> (n, heap_dis, heap_ids, vals + n, ids + n, i - n);
207
- heap_reorder<C> (n, heap_dis, heap_ids);
183
+ heap_addn<C>(n, heap_dis, heap_ids, vals + n, ids + n, i - n);
184
+ heap_reorder<C>(n, heap_dis, heap_ids);
208
185
  }
209
-
210
186
  }
211
-
212
187
  };
213
188
 
214
-
215
-
216
- template<class C>
189
+ template <class C>
217
190
  struct ReservoirResultHandler {
218
-
219
191
  using T = typename C::T;
220
192
  using TI = typename C::TI;
221
193
 
222
194
  int nq;
223
- T *heap_dis_tab;
224
- TI *heap_ids_tab;
195
+ T* heap_dis_tab;
196
+ TI* heap_ids_tab;
225
197
 
226
- int64_t k; // number of results to keep
198
+ int64_t k; // number of results to keep
227
199
  size_t capacity; // capacity of the reservoirs
228
200
 
229
201
  ReservoirResultHandler(
230
- size_t nq,
231
- T * heap_dis_tab, TI * heap_ids_tab,
232
- size_t k):
233
- nq(nq),
234
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
235
- {
202
+ size_t nq,
203
+ T* heap_dis_tab,
204
+ TI* heap_ids_tab,
205
+ size_t k)
206
+ : nq(nq),
207
+ heap_dis_tab(heap_dis_tab),
208
+ heap_ids_tab(heap_ids_tab),
209
+ k(k) {
236
210
  // double then round up to multiple of 16 (for SIMD alignment)
237
211
  capacity = (2 * k + 15) & ~15;
238
212
  }
@@ -243,23 +217,26 @@ struct ReservoirResultHandler {
243
217
  */
244
218
 
245
219
  struct SingleResultHandler {
246
- ReservoirResultHandler & hr;
220
+ ReservoirResultHandler& hr;
247
221
 
248
222
  std::vector<T> reservoir_dis;
249
223
  std::vector<TI> reservoir_ids;
250
224
  ReservoirTopN<C> res1;
251
225
 
252
- SingleResultHandler(ReservoirResultHandler &hr):
253
- hr(hr), reservoir_dis(hr.capacity), reservoir_ids(hr.capacity)
254
- {
255
- }
226
+ SingleResultHandler(ReservoirResultHandler& hr)
227
+ : hr(hr),
228
+ reservoir_dis(hr.capacity),
229
+ reservoir_ids(hr.capacity) {}
256
230
 
257
231
  size_t i;
258
232
 
259
233
  /// begin results for query # i
260
234
  void begin(size_t i) {
261
235
  res1 = ReservoirTopN<C>(
262
- hr.k, hr.capacity, reservoir_dis.data(), reservoir_ids.data());
236
+ hr.k,
237
+ hr.capacity,
238
+ reservoir_dis.data(),
239
+ reservoir_ids.data());
263
240
  this->i = i;
264
241
  }
265
242
 
@@ -270,8 +247,8 @@ struct ReservoirResultHandler {
270
247
 
271
248
  /// series of results for query i is done
272
249
  void end() {
273
- T * heap_dis = hr.heap_dis_tab + i * hr.k;
274
- TI * heap_ids = hr.heap_ids_tab + i * hr.k;
250
+ T* heap_dis = hr.heap_dis_tab + i * hr.k;
251
+ TI* heap_ids = hr.heap_ids_tab + i * hr.k;
275
252
  res1.to_result(heap_dis, heap_ids);
276
253
  }
277
254
  };
@@ -295,20 +272,22 @@ struct ReservoirResultHandler {
295
272
  reservoirs.clear();
296
273
  for (size_t i = i0; i < i1; i++) {
297
274
  reservoirs.emplace_back(
298
- k, capacity,
299
- reservoir_dis.data() + (i - i0) * capacity,
300
- reservoir_ids.data() + (i - i0) * capacity
301
- );
275
+ k,
276
+ capacity,
277
+ reservoir_dis.data() + (i - i0) * capacity,
278
+ reservoir_ids.data() + (i - i0) * capacity);
302
279
  }
303
280
  }
304
281
 
305
282
  /// add results for query i0..i1 and j0..j1
306
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
283
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
307
284
  // maybe parallel for
308
- for (size_t i = i0; i < i1; i++) {
309
- ReservoirTopN<C> & reservoir = reservoirs[i - i0];
285
+ #pragma omp parallel for
286
+ for (int64_t i = i0; i < i1; i++) {
287
+ ReservoirTopN<C>& reservoir = reservoirs[i - i0];
288
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
310
289
  for (size_t j = j0; j < j1; j++) {
311
- T dis = *dis_tab++;
290
+ T dis = dis_tab_i[j];
312
291
  reservoir.add(dis, j);
313
292
  }
314
293
  }
@@ -317,32 +296,27 @@ struct ReservoirResultHandler {
317
296
  /// series of results for queries i0..i1 is done
318
297
  void end_multiple() {
319
298
  // maybe parallel for
320
- for(size_t i = i0; i < i1; i++) {
299
+ for (size_t i = i0; i < i1; i++) {
321
300
  reservoirs[i - i0].to_result(
322
- heap_dis_tab + i * k, heap_ids_tab + i * k);
301
+ heap_dis_tab + i * k, heap_ids_tab + i * k);
323
302
  }
324
303
  }
325
-
326
304
  };
327
305
 
328
-
329
306
  /*****************************************************************
330
307
  * Result handler for range searches
331
308
  *****************************************************************/
332
309
 
333
-
334
-
335
- template<class C>
310
+ template <class C>
336
311
  struct RangeSearchResultHandler {
337
312
  using T = typename C::T;
338
313
  using TI = typename C::TI;
339
314
 
340
- RangeSearchResult *res;
315
+ RangeSearchResult* res;
341
316
  float radius;
342
317
 
343
- RangeSearchResultHandler(RangeSearchResult *res, float radius):
344
- res(res), radius(radius)
345
- {}
318
+ RangeSearchResultHandler(RangeSearchResult* res, float radius)
319
+ : res(res), radius(radius) {}
346
320
 
347
321
  /******************************************************
348
322
  * API for 1 result at a time (each SingleResultHandler is
@@ -353,11 +327,10 @@ struct RangeSearchResultHandler {
353
327
  // almost the same interface as RangeSearchResultHandler
354
328
  RangeSearchPartialResult pres;
355
329
  float radius;
356
- RangeQueryResult *qr = nullptr;
330
+ RangeQueryResult* qr = nullptr;
357
331
 
358
- SingleResultHandler(RangeSearchResultHandler &rh):
359
- pres(rh.res), radius(rh.radius)
360
- {}
332
+ SingleResultHandler(RangeSearchResultHandler& rh)
333
+ : pres(rh.res), radius(rh.radius) {}
361
334
 
362
335
  /// begin results for query # i
363
336
  void begin(size_t i) {
@@ -366,15 +339,13 @@ struct RangeSearchResultHandler {
366
339
 
367
340
  /// add one result for query i
368
341
  void add_result(T dis, TI idx) {
369
-
370
342
  if (C::cmp(radius, dis)) {
371
343
  qr->add(dis, idx);
372
344
  }
373
345
  }
374
346
 
375
347
  /// series of results for query i is done
376
- void end() {
377
- }
348
+ void end() {}
378
349
 
379
350
  ~SingleResultHandler() {
380
351
  pres.finalize();
@@ -387,8 +358,8 @@ struct RangeSearchResultHandler {
387
358
 
388
359
  size_t i0, i1;
389
360
 
390
- std::vector <RangeSearchPartialResult *> partial_results;
391
- std::vector <size_t> j0s;
361
+ std::vector<RangeSearchPartialResult*> partial_results;
362
+ std::vector<size_t> j0s;
392
363
  int pr = 0;
393
364
 
394
365
  /// begin
@@ -399,8 +370,8 @@ struct RangeSearchResultHandler {
399
370
 
400
371
  /// add results for query i0..i1 and j0..j1
401
372
 
402
- void add_results(size_t j0, size_t j1, const T *dis_tab) {
403
- RangeSearchPartialResult *pres;
373
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
374
+ RangeSearchPartialResult* pres;
404
375
  // there is one RangeSearchPartialResult structure per j0
405
376
  // (= block of columns of the large distance matrix)
406
377
  // it is a bit tricky to find the poper PartialResult structure
@@ -414,39 +385,32 @@ struct RangeSearchResultHandler {
414
385
  pres = partial_results[pr];
415
386
  pr++;
416
387
  } else { // did not find this j0
417
- pres = new RangeSearchPartialResult (res);
388
+ pres = new RangeSearchPartialResult(res);
418
389
  partial_results.push_back(pres);
419
390
  j0s.push_back(j0);
420
391
  pr = partial_results.size();
421
392
  }
422
393
 
423
394
  for (size_t i = i0; i < i1; i++) {
424
- const float *ip_line = dis_tab + (i - i0) * (j1 - j0);
425
- RangeQueryResult & qres = pres->new_result (i);
395
+ const float* ip_line = dis_tab + (i - i0) * (j1 - j0);
396
+ RangeQueryResult& qres = pres->new_result(i);
426
397
 
427
398
  for (size_t j = j0; j < j1; j++) {
428
399
  float dis = *ip_line++;
429
400
  if (C::cmp(radius, dis)) {
430
- qres.add (dis, j);
401
+ qres.add(dis, j);
431
402
  }
432
403
  }
433
404
  }
434
405
  }
435
406
 
436
- void end_multiple() {
437
-
438
- }
407
+ void end_multiple() {}
439
408
 
440
409
  ~RangeSearchResultHandler() {
441
410
  if (partial_results.size() > 0) {
442
- RangeSearchPartialResult::merge (partial_results);
411
+ RangeSearchPartialResult::merge(partial_results);
443
412
  }
444
413
  }
445
-
446
414
  };
447
415
 
448
-
449
-
450
-
451
- } // namespace faiss
452
-
416
+ } // namespace faiss