faiss 0.1.4 → 0.2.1

Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +26 -1
  3. data/README.md +15 -3
  4. data/ext/faiss/ext.cpp +12 -308
  5. data/ext/faiss/extconf.rb +5 -2
  6. data/ext/faiss/index.cpp +189 -0
  7. data/ext/faiss/index_binary.cpp +75 -0
  8. data/ext/faiss/kmeans.cpp +40 -0
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +33 -0
  11. data/ext/faiss/product_quantizer.cpp +53 -0
  12. data/ext/faiss/utils.cpp +13 -0
  13. data/ext/faiss/utils.h +5 -0
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +31 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
data/vendor/faiss/faiss/IndexIVF.h

@@ -10,30 +10,27 @@
  #ifndef FAISS_INDEX_IVF_H
  #define FAISS_INDEX_IVF_H
 
-
- #include <vector>
- #include <unordered_map>
  #include <stdint.h>
+ #include <unordered_map>
+ #include <vector>
 
- #include <faiss/Index.h>
- #include <faiss/invlists/InvertedLists.h>
- #include <faiss/invlists/DirectMap.h>
  #include <faiss/Clustering.h>
+ #include <faiss/Index.h>
  #include <faiss/impl/platform_macros.h>
+ #include <faiss/invlists/DirectMap.h>
+ #include <faiss/invlists/InvertedLists.h>
  #include <faiss/utils/Heap.h>
 
-
  namespace faiss {
 
-
  /** Encapsulates a quantizer object for the IndexIVF
  *
  * The class isolates the fields that are independent of the storage
  * of the lists (especially training)
  */
  struct Level1Quantizer {
- Index * quantizer; ///< quantizer that maps vectors to inverted lists
- size_t nlist; ///< number of possible key values
+ Index* quantizer; ///< quantizer that maps vectors to inverted lists
+ size_t nlist; ///< number of possible key values
 
  /**
  * = 0: use the quantizer as index in a kmeans training
@@ -41,40 +38,37 @@ struct Level1Quantizer {
  * = 2: kmeans training on a flat index + add the centroids to the quantizer
  */
  char quantizer_trains_alone;
- bool own_fields; ///< whether object owns the quantizer
+ bool own_fields; ///< whether object owns the quantizer
 
  ClusteringParameters cp; ///< to override default clustering params
- Index *clustering_index; ///< to override index used during clustering
+ Index* clustering_index; ///< to override index used during clustering
 
  /// Trains the quantizer and calls train_residual to train sub-quantizers
- void train_q1 (size_t n, const float *x, bool verbose,
- MetricType metric_type);
-
+ void train_q1(
+ size_t n,
+ const float* x,
+ bool verbose,
+ MetricType metric_type);
 
  /// compute the number of bytes required to store list ids
- size_t coarse_code_size () const;
- void encode_listno (Index::idx_t list_no, uint8_t *code) const;
- Index::idx_t decode_listno (const uint8_t *code) const;
+ size_t coarse_code_size() const;
+ void encode_listno(Index::idx_t list_no, uint8_t* code) const;
+ Index::idx_t decode_listno(const uint8_t* code) const;
 
- Level1Quantizer (Index * quantizer, size_t nlist);
+ Level1Quantizer(Index* quantizer, size_t nlist);
 
- Level1Quantizer ();
-
- ~Level1Quantizer ();
+ Level1Quantizer();
 
+ ~Level1Quantizer();
  };
 
-
-
  struct IVFSearchParameters {
- size_t nprobe; ///< number of probes at query time
- size_t max_codes; ///< max nb of codes to visit to do a query
- IVFSearchParameters(): nprobe(1), max_codes(0) {}
- virtual ~IVFSearchParameters () {}
+ size_t nprobe; ///< number of probes at query time
+ size_t max_codes; ///< max nb of codes to visit to do a query
+ IVFSearchParameters() : nprobe(1), max_codes(0) {}
+ virtual ~IVFSearchParameters() {}
  };
 
-
-
  struct InvertedListScanner;
  struct IndexIVFStats;
 
@@ -98,15 +92,15 @@ struct IndexIVFStats;
  * Sub-classes implement a post-filtering of the index that refines
  * the distance estimation from the query to databse vectors.
  */
- struct IndexIVF: Index, Level1Quantizer {
+ struct IndexIVF : Index, Level1Quantizer {
  /// Access to the actual data
- InvertedLists *invlists;
+ InvertedLists* invlists;
  bool own_invlists;
 
- size_t code_size; ///< code size per vector in bytes
+ size_t code_size; ///< code size per vector in bytes
 
- size_t nprobe; ///< number of probes at query time
- size_t max_codes; ///< max nb of codes to visit to do a query
+ size_t nprobe; ///< number of probes at query time
+ size_t max_codes; ///< max nb of codes to visit to do a query
 
  /** Parallel mode determines how queries are parallelized with OpenMP
  *
@@ -130,9 +124,12 @@ struct IndexIVF: Index, Level1Quantizer {
  * identifier. The pointer is borrowed: the quantizer should not
  * be deleted while the IndexIVF is in use.
  */
- IndexIVF (Index * quantizer, size_t d,
- size_t nlist, size_t code_size,
- MetricType metric = METRIC_L2);
+ IndexIVF(
+ Index* quantizer,
+ size_t d,
+ size_t nlist,
+ size_t code_size,
+ MetricType metric = METRIC_L2);
 
  void reset() override;
 
@@ -145,6 +142,19 @@ struct IndexIVF: Index, Level1Quantizer {
  /// default implementation that calls encode_vectors
  void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
 
+ /** Implementation of vector addition where the vector assignments are
+ * predefined. The default implementation hands over the code extraction to
+ * encode_vectors.
+ *
+ * @param precomputed_idx quantization indices for the input vectors
+ * (size n)
+ */
+ virtual void add_core(
+ idx_t n,
+ const float* x,
+ const idx_t* xids,
+ const idx_t* precomputed_idx);
+
  /** Encodes a set of vectors as they would appear in the inverted lists
  *
  * @param list_nos inverted list ids as returned by the
@@ -154,14 +164,16 @@ struct IndexIVF: Index, Level1Quantizer {
  * include the list ids in the code (in this case add
  * ceil(log8(nlist)) to the code size)
  */
- virtual void encode_vectors(idx_t n, const float* x,
- const idx_t *list_nos,
- uint8_t * codes,
- bool include_listno = false) const = 0;
+ virtual void encode_vectors(
+ idx_t n,
+ const float* x,
+ const idx_t* list_nos,
+ uint8_t* codes,
+ bool include_listno = false) const = 0;
 
  /// Sub-classes that encode the residuals can train their encoders here
  /// does nothing by default
- virtual void train_residual (idx_t n, const float *x);
+ virtual void train_residual(idx_t n, const float* x);
 
  /** search a set of vectors, that are pre-quantized by the IVF
  * quantizer. Fill in the corresponding heaps with the query
@@ -182,36 +194,50 @@ struct IndexIVF: Index, Level1Quantizer {
  * @param params used to override the object's search parameters
  * @param stats search stats to be updated (can be null)
  */
- virtual void search_preassigned (
- idx_t n, const float *x, idx_t k,
- const idx_t *assign, const float *centroid_dis,
- float *distances, idx_t *labels,
+ virtual void search_preassigned(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ const idx_t* assign,
+ const float* centroid_dis,
+ float* distances,
+ idx_t* labels,
  bool store_pairs,
- const IVFSearchParameters *params=nullptr,
- IndexIVFStats *stats=nullptr
- ) const;
+ const IVFSearchParameters* params = nullptr,
+ IndexIVFStats* stats = nullptr) const;
 
  /** assign the vectors, then call search_preassign */
- void search (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels) const override;
-
- void range_search (idx_t n, const float* x, float radius,
- RangeSearchResult* result) const override;
+ void search(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const override;
+
+ void range_search(
+ idx_t n,
+ const float* x,
+ float radius,
+ RangeSearchResult* result) const override;
 
  void range_search_preassigned(
- idx_t nx, const float *x, float radius,
- const idx_t *keys, const float *coarse_dis,
- RangeSearchResult *result,
- bool store_pairs=false,
- const IVFSearchParameters *params=nullptr,
- IndexIVFStats *stats=nullptr) const;
+ idx_t nx,
+ const float* x,
+ float radius,
+ const idx_t* keys,
+ const float* coarse_dis,
+ RangeSearchResult* result,
+ bool store_pairs = false,
+ const IVFSearchParameters* params = nullptr,
+ IndexIVFStats* stats = nullptr) const;
 
  /// get a scanner for this index (store_pairs means ignore labels)
- virtual InvertedListScanner *get_InvertedListScanner (
- bool store_pairs=false) const;
+ virtual InvertedListScanner* get_InvertedListScanner(
+ bool store_pairs = false) const;
 
- /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2 */
- void reconstruct (idx_t key, float* recons) const override;
+ /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
+ */
+ void reconstruct(idx_t key, float* recons) const override;
 
  /** Update a subset of vectors.
  *
@@ -221,7 +247,7 @@ struct IndexIVF: Index, Level1Quantizer {
  * @param idx vector indices to update, size nv
  * @param v vectors of new values, size nv*d
  */
- virtual void update_vectors (int nv, const idx_t *idx, const float *v);
+ virtual void update_vectors(int nv, const idx_t* idx, const float* v);
 
  /** Reconstruct a subset of the indexed vectors.
  *
@@ -243,9 +269,13 @@ struct IndexIVF: Index, Level1Quantizer {
  *
  * @param recons reconstructed vectors size (n, k, d)
  */
- void search_and_reconstruct (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels,
- float *recons) const override;
+ void search_and_reconstruct(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ float* recons) const override;
 
  /** Reconstruct a vector given the location in terms of (inv list index +
  * inv list offset) instead of the id.
@@ -254,9 +284,10 @@ struct IndexIVF: Index, Level1Quantizer {
  * the inv list offset is computed by search_preassigned() with
  * `store_pairs` set.
  */
- virtual void reconstruct_from_offset (int64_t list_no, int64_t offset,
- float* recons) const;
-
+ virtual void reconstruct_from_offset(
+ int64_t list_no,
+ int64_t offset,
+ float* recons) const;
 
  /// Dataset manipulation functions
 
@@ -265,12 +296,12 @@ struct IndexIVF: Index, Level1Quantizer {
  /** check that the two indexes are compatible (ie, they are
  * trained in the same way and have the same
  * parameters). Otherwise throw. */
- void check_compatible_for_merge (const IndexIVF &other) const;
+ void check_compatible_for_merge(const IndexIVF& other) const;
 
  /** moves the entries from another dataset to self. On output,
  * other is empty. add_id is added to all moved ids (for
  * sequential ids, this would be this->ntotal */
- virtual void merge_from (IndexIVF &other, idx_t add_id);
+ virtual void merge_from(IndexIVF& other, idx_t add_id);
 
  /** copy a subset of the entries index to the other index
  *
@@ -279,34 +310,36 @@ struct IndexIVF: Index, Level1Quantizer {
  * if subset_type == 2: copies inverted lists such that a1
  * elements are left before and a2 elements are after
  */
- virtual void copy_subset_to (IndexIVF & other, int subset_type,
- idx_t a1, idx_t a2) const;
+ virtual void copy_subset_to(
+ IndexIVF& other,
+ int subset_type,
+ idx_t a1,
+ idx_t a2) const;
 
  ~IndexIVF() override;
 
- size_t get_list_size (size_t list_no) const
- { return invlists->list_size(list_no); }
+ size_t get_list_size(size_t list_no) const {
+ return invlists->list_size(list_no);
+ }
 
  /** intialize a direct map
  *
  * @param new_maintain_direct_map if true, create a direct map,
  * else clear it
  */
- void make_direct_map (bool new_maintain_direct_map=true);
-
- void set_direct_map_type (DirectMap::Type type);
+ void make_direct_map(bool new_maintain_direct_map = true);
 
+ void set_direct_map_type(DirectMap::Type type);
 
  /// replace the inverted lists, old one is deallocated if own_invlists
- void replace_invlists (InvertedLists *il, bool own=false);
+ void replace_invlists(InvertedLists* il, bool own = false);
 
  /* The standalone codec interface (except sa_decode that is specific) */
- size_t sa_code_size () const override;
+ size_t sa_code_size() const override;
 
- void sa_encode (idx_t n, const float *x,
- uint8_t *bytes) const override;
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
 
- IndexIVF ();
+ IndexIVF();
  };
 
  struct RangeQueryResult;
@@ -316,17 +349,16 @@ struct RangeQueryResult;
  * distance_to_code and scan_codes can be called in multiple
  * threads */
  struct InvertedListScanner {
-
  using idx_t = Index::idx_t;
 
  /// from now on we handle this query.
- virtual void set_query (const float *query_vector) = 0;
+ virtual void set_query(const float* query_vector) = 0;
 
  /// following codes come from this inverted list
- virtual void set_list (idx_t list_no, float coarse_dis) = 0;
+ virtual void set_list(idx_t list_no, float coarse_dis) = 0;
 
  /// compute a single query-to-code distance
- virtual float distance_to_code (const uint8_t *code) const = 0;
+ virtual float distance_to_code(const uint8_t* code) const = 0;
 
  /** scan a set of codes, compute distances to current query and
  * update heap of results if necessary.
@@ -339,45 +371,46 @@ struct InvertedListScanner {
  * @param k heap size
  * @return number of heap updates performed
  */
- virtual size_t scan_codes (size_t n,
- const uint8_t *codes,
- const idx_t *ids,
- float *distances, idx_t *labels,
- size_t k) const = 0;
+ virtual size_t scan_codes(
+ size_t n,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float* distances,
+ idx_t* labels,
+ size_t k) const = 0;
 
  /** scan a set of codes, compute distances to current query and
  * update results if distances are below radius
  *
  * (default implementation fails) */
- virtual void scan_codes_range (size_t n,
- const uint8_t *codes,
- const idx_t *ids,
- float radius,
- RangeQueryResult &result) const;
-
- virtual ~InvertedListScanner () {}
-
+ virtual void scan_codes_range(
+ size_t n,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float radius,
+ RangeQueryResult& result) const;
+
+ virtual ~InvertedListScanner() {}
  };
 
-
  struct IndexIVFStats {
- size_t nq; // nb of queries run
- size_t nlist; // nb of inverted lists scanned
- size_t ndis; // nb of distancs computed
- size_t nheap_updates; // nb of times the heap was updated
+ size_t nq; // nb of queries run
+ size_t nlist; // nb of inverted lists scanned
+ size_t ndis; // nb of distancs computed
+ size_t nheap_updates; // nb of times the heap was updated
  double quantization_time; // time spent quantizing vectors (in ms)
  double search_time; // time spent searching lists (in ms)
 
- IndexIVFStats () {reset (); }
- void reset ();
- void add (const IndexIVFStats & other);
+ IndexIVFStats() {
+ reset();
+ }
+ void reset();
+ void add(const IndexIVFStats& other);
  };
 
  // global var that collects them all
  FAISS_API extern IndexIVFStats indexIVF_stats;
 
-
  } // namespace faiss
 
-
  #endif
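The main API addition in this header is `add_core`, which lets the caller pass precomputed coarse-quantization indices instead of having `add_with_ids` run the coarse quantizer itself. A minimal sketch of how it might be called (dimensions, sizes, and the random data are illustrative, not part of the diff; it assumes the vendored faiss where `Index::idx_t` is `int64_t`):

```cpp
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>

#include <random>
#include <vector>

int main() {
    const int d = 64;        // vector dimensionality (illustrative)
    const size_t nlist = 16; // number of inverted lists (illustrative)
    const faiss::Index::idx_t n = 1000;

    faiss::IndexFlatL2 quantizer(d);
    faiss::IndexIVFFlat index(&quantizer, d, nlist, faiss::METRIC_L2);

    // some random training/database vectors
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    std::vector<float> x(n * d);
    for (auto& v : x) v = dist(rng);

    index.train(n, x.data());

    // precompute the coarse assignments once, e.g. to batch or reuse them...
    std::vector<faiss::Index::idx_t> coarse(n);
    quantizer.assign(n, x.data(), coarse.data());

    // ...then hand them to add_core; ids are sequential when xids is null
    index.add_core(n, x.data(), /*xids=*/nullptr, coarse.data());
    return 0;
}
```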
data/vendor/faiss/faiss/IndexIVFFlat.cpp

@@ -9,332 +9,340 @@
 
  #include <faiss/IndexIVFFlat.h>
 
+ #include <omp.h>
+
  #include <cinttypes>
  #include <cstdio>
 
  #include <faiss/IndexFlat.h>
 
+ #include <faiss/impl/AuxIndexStructures.h>
+ #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/distances.h>
  #include <faiss/utils/utils.h>
- #include <faiss/impl/FaissAssert.h>
- #include <faiss/impl/AuxIndexStructures.h>
-
 
  namespace faiss {
 
-
  /*****************************************
  * IndexIVFFlat implementation
  ******************************************/
 
- IndexIVFFlat::IndexIVFFlat (Index * quantizer,
- size_t d, size_t nlist, MetricType metric):
- IndexIVF (quantizer, d, nlist, sizeof(float) * d, metric)
- {
+ IndexIVFFlat::IndexIVFFlat(
+ Index* quantizer,
+ size_t d,
+ size_t nlist,
+ MetricType metric)
+ : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
  code_size = sizeof(float) * d;
  }
 
+ void IndexIVFFlat::add_core(
+ idx_t n,
+ const float* x,
+ const int64_t* xids,
+ const int64_t* coarse_idx)
 
- void IndexIVFFlat::add_with_ids (idx_t n, const float * x, const idx_t *xids)
  {
- add_core (n, x, xids, nullptr);
- }
-
- void IndexIVFFlat::add_core (idx_t n, const float * x, const int64_t *xids,
- const int64_t *precomputed_idx)
+ FAISS_THROW_IF_NOT(is_trained);
+ FAISS_THROW_IF_NOT(coarse_idx);
+ assert(invlists);
+ direct_map.check_can_add(xids);
 
- {
- FAISS_THROW_IF_NOT (is_trained);
- assert (invlists);
- direct_map.check_can_add (xids);
- const int64_t * idx;
- ScopeDeleter<int64_t> del;
-
- if (precomputed_idx) {
- idx = precomputed_idx;
- } else {
- int64_t * idx0 = new int64_t [n];
- del.set (idx0);
- quantizer->assign (n, x, idx0);
- idx = idx0;
- }
  int64_t n_add = 0;
- for (size_t i = 0; i < n; i++) {
- idx_t id = xids ? xids[i] : ntotal + i;
- idx_t list_no = idx [i];
- size_t offset;
-
- if (list_no >= 0) {
- const float *xi = x + i * d;
- offset = invlists->add_entry (
- list_no, id, (const uint8_t*) xi);
- n_add++;
- } else {
- offset = 0;
+
+ DirectMapAdd dm_adder(direct_map, n, xids);
+
+ #pragma omp parallel reduction(+ : n_add)
+ {
+ int nt = omp_get_num_threads();
+ int rank = omp_get_thread_num();
+
+ // each thread takes care of a subset of lists
+ for (size_t i = 0; i < n; i++) {
+ idx_t list_no = coarse_idx[i];
+
+ if (list_no >= 0 && list_no % nt == rank) {
+ idx_t id = xids ? xids[i] : ntotal + i;
+ const float* xi = x + i * d;
+ size_t offset =
+ invlists->add_entry(list_no, id, (const uint8_t*)xi);
+ dm_adder.add(i, list_no, offset);
+ n_add++;
+ } else if (rank == 0 && list_no == -1) {
+ dm_adder.add(i, -1, 0);
+ }
  }
- direct_map.add_single_id (id, list_no, offset);
  }
 
  if (verbose) {
- printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64 " vectors\n",
- n_add, n);
+ printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64
+ " vectors\n",
+ n_add,
+ n);
  }
  ntotal += n;
  }
 
- void IndexIVFFlat::encode_vectors(idx_t n, const float* x,
- const idx_t * list_nos,
- uint8_t * codes,
- bool include_listnos) const
- {
+ void IndexIVFFlat::encode_vectors(
+ idx_t n,
+ const float* x,
+ const idx_t* list_nos,
+ uint8_t* codes,
+ bool include_listnos) const {
  if (!include_listnos) {
- memcpy (codes, x, code_size * n);
+ memcpy(codes, x, code_size * n);
  } else {
- size_t coarse_size = coarse_code_size ();
+ size_t coarse_size = coarse_code_size();
  for (size_t i = 0; i < n; i++) {
- int64_t list_no = list_nos [i];
- uint8_t *code = codes + i * (code_size + coarse_size);
- const float *xi = x + i * d;
+ int64_t list_no = list_nos[i];
+ uint8_t* code = codes + i * (code_size + coarse_size);
+ const float* xi = x + i * d;
  if (list_no >= 0) {
- encode_listno (list_no, code);
- memcpy (code + coarse_size, xi, code_size);
+ encode_listno(list_no, code);
+ memcpy(code + coarse_size, xi, code_size);
  } else {
- memset (code, 0, code_size + coarse_size);
+ memset(code, 0, code_size + coarse_size);
  }
-
  }
  }
  }
 
- void IndexIVFFlat::sa_decode (idx_t n, const uint8_t *bytes,
- float *x) const
- {
- size_t coarse_size = coarse_code_size ();
+ void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
+ size_t coarse_size = coarse_code_size();
  for (size_t i = 0; i < n; i++) {
- const uint8_t *code = bytes + i * (code_size + coarse_size);
- float *xi = x + i * d;
- memcpy (xi, code + coarse_size, code_size);
+ const uint8_t* code = bytes + i * (code_size + coarse_size);
+ float* xi = x + i * d;
+ memcpy(xi, code + coarse_size, code_size);
  }
  }
 
-
  namespace {
 
-
- template<MetricType metric, class C>
- struct IVFFlatScanner: InvertedListScanner {
+ template <MetricType metric, class C>
+ struct IVFFlatScanner : InvertedListScanner {
  size_t d;
  bool store_pairs;
 
- IVFFlatScanner(size_t d, bool store_pairs):
- d(d), store_pairs(store_pairs) {}
+ IVFFlatScanner(size_t d, bool store_pairs)
+ : d(d), store_pairs(store_pairs) {}
 
- const float *xi;
- void set_query (const float *query) override {
+ const float* xi;
+ void set_query(const float* query) override {
  this->xi = query;
  }
 
  idx_t list_no;
- void set_list (idx_t list_no, float /* coarse_dis */) override {
+ void set_list(idx_t list_no, float /* coarse_dis */) override {
  this->list_no = list_no;
  }
 
- float distance_to_code (const uint8_t *code) const override {
- const float *yj = (float*)code;
- float dis = metric == METRIC_INNER_PRODUCT ?
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
+ float distance_to_code(const uint8_t* code) const override {
+ const float* yj = (float*)code;
+ float dis = metric == METRIC_INNER_PRODUCT
+ ? fvec_inner_product(xi, yj, d)
+ : fvec_L2sqr(xi, yj, d);
  return dis;
  }
 
- size_t scan_codes (size_t list_size,
- const uint8_t *codes,
- const idx_t *ids,
- float *simi, idx_t *idxi,
- size_t k) const override
- {
- const float *list_vecs = (const float*)codes;
+ size_t scan_codes(
+ size_t list_size,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float* simi,
+ idx_t* idxi,
+ size_t k) const override {
+ const float* list_vecs = (const float*)codes;
  size_t nup = 0;
  for (size_t j = 0; j < list_size; j++) {
- const float * yj = list_vecs + d * j;
- float dis = metric == METRIC_INNER_PRODUCT ?
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
- if (C::cmp (simi[0], dis)) {
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
- heap_replace_top<C> (k, simi, idxi, dis, id);
+ const float* yj = list_vecs + d * j;
+ float dis = metric == METRIC_INNER_PRODUCT
+ ? fvec_inner_product(xi, yj, d)
+ : fvec_L2sqr(xi, yj, d);
+ if (C::cmp(simi[0], dis)) {
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
+ heap_replace_top<C>(k, simi, idxi, dis, id);
  nup++;
  }
  }
  return nup;
  }
 
- void scan_codes_range (size_t list_size,
- const uint8_t *codes,
- const idx_t *ids,
- float radius,
- RangeQueryResult & res) const override
- {
- const float *list_vecs = (const float*)codes;
+ void scan_codes_range(
+ size_t list_size,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float radius,
+ RangeQueryResult& res) const override {
+ const float* list_vecs = (const float*)codes;
  for (size_t j = 0; j < list_size; j++) {
- const float * yj = list_vecs + d * j;
- float dis = metric == METRIC_INNER_PRODUCT ?
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
- if (C::cmp (radius, dis)) {
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
- res.add (dis, id);
+ const float* yj = list_vecs + d * j;
+ float dis = metric == METRIC_INNER_PRODUCT
+ ? fvec_inner_product(xi, yj, d)
+ : fvec_L2sqr(xi, yj, d);
+ if (C::cmp(radius, dis)) {
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
+ res.add(dis, id);
  }
  }
  }
-
-
  };
 
-
  } // anonymous namespace
 
-
-
- InvertedListScanner* IndexIVFFlat::get_InvertedListScanner
- (bool store_pairs) const
- {
+ InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
+ bool store_pairs) const {
  if (metric_type == METRIC_INNER_PRODUCT) {
- return new IVFFlatScanner<
- METRIC_INNER_PRODUCT, CMin<float, int64_t> > (d, store_pairs);
+ return new IVFFlatScanner<METRIC_INNER_PRODUCT, CMin<float, int64_t>>(
+ d, store_pairs);
  } else if (metric_type == METRIC_L2) {
- return new IVFFlatScanner<
- METRIC_L2, CMax<float, int64_t> >(d, store_pairs);
+ return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>>(
+ d, store_pairs);
  } else {
  FAISS_THROW_MSG("metric type not supported");
  }
  return nullptr;
  }
 
-
-
-
- void IndexIVFFlat::reconstruct_from_offset (int64_t list_no, int64_t offset,
- float* recons) const
- {
- memcpy (recons, invlists->get_single_code (list_no, offset), code_size);
+ void IndexIVFFlat::reconstruct_from_offset(
+ int64_t list_no,
+ int64_t offset,
+ float* recons) const {
+ memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
  }
 
  /*****************************************
  * IndexIVFFlatDedup implementation
  ******************************************/
 
- IndexIVFFlatDedup::IndexIVFFlatDedup (
- Index * quantizer, size_t d, size_t nlist_,
- MetricType metric_type):
- IndexIVFFlat (quantizer, d, nlist_, metric_type)
- {}
-
+ IndexIVFFlatDedup::IndexIVFFlatDedup(
+ Index* quantizer,
+ size_t d,
+ size_t nlist_,
+ MetricType metric_type)
+ : IndexIVFFlat(quantizer, d, nlist_, metric_type) {}
 
- void IndexIVFFlatDedup::train(idx_t n, const float* x)
- {
+ void IndexIVFFlatDedup::train(idx_t n, const float* x) {
  std::unordered_map<uint64_t, idx_t> map;
- float * x2 = new float [n * d];
- ScopeDeleter<float> del (x2);
+ float* x2 = new float[n * d];
+ ScopeDeleter<float> del(x2);
 
  int64_t n2 = 0;
  for (int64_t i = 0; i < n; i++) {
- uint64_t hash = hash_bytes((uint8_t *)(x + i * d), code_size);
+ uint64_t hash = hash_bytes((uint8_t*)(x + i * d), code_size);
  if (map.count(hash) &&
- !memcmp (x2 + map[hash] * d, x + i * d, code_size)) {
+ !memcmp(x2 + map[hash] * d, x + i * d, code_size)) {
  // is duplicate, skip
  } else {
- map [hash] = n2;
- memcpy (x2 + n2 * d, x + i * d, code_size);
- n2 ++;
+ map[hash] = n2;
+ memcpy(x2 + n2 * d, x + i * d, code_size);
+ n2++;
  }
  }
  if (verbose) {
- printf ("IndexIVFFlatDedup::train: train on %" PRId64 " points after dedup "
- "(was %" PRId64 " points)\n", n2, n);
+ printf("IndexIVFFlatDedup::train: train on %" PRId64
+ " points after dedup "
+ "(was %" PRId64 " points)\n",
+ n2,
+ n);
  }
- IndexIVFFlat::train (n2, x2);
+ IndexIVFFlat::train(n2, x2);
  }
 
+ void IndexIVFFlatDedup::add_with_ids(
+ idx_t na,
+ const float* x,
+ const idx_t* xids) {
+ FAISS_THROW_IF_NOT(is_trained);
+ assert(invlists);
+ FAISS_THROW_IF_NOT_MSG(
+ direct_map.no(), "IVFFlatDedup not implemented with direct_map");
+ int64_t* idx = new int64_t[na];
+ ScopeDeleter<int64_t> del(idx);
+ quantizer->assign(na, x, idx);
 
+ int64_t n_add = 0, n_dup = 0;
 
- void IndexIVFFlatDedup::add_with_ids(
- idx_t na, const float* x, const idx_t* xids)
- {
+ #pragma omp parallel reduction(+ : n_add, n_dup)
+ {
+ int nt = omp_get_num_threads();
+ int rank = omp_get_thread_num();
 
- FAISS_THROW_IF_NOT (is_trained);
- assert (invlists);
- FAISS_THROW_IF_NOT_MSG (direct_map.no(),
- "IVFFlatDedup not implemented with direct_map");
- int64_t * idx = new int64_t [na];
- ScopeDeleter<int64_t> del (idx);
- quantizer->assign (na, x, idx);
+ // each thread takes care of a subset of lists
+ for (size_t i = 0; i < na; i++) {
+ int64_t list_no = idx[i];
 
- int64_t n_add = 0, n_dup = 0;
- // TODO make a omp loop with this
- for (size_t i = 0; i < na; i++) {
- idx_t id = xids ? xids[i] : ntotal + i;
- int64_t list_no = idx [i];
+ if (list_no < 0 || list_no % nt != rank) {
+ continue;
+ }
 
- if (list_no < 0) {
- continue;
- }
- const float *xi = x + i * d;
+ idx_t id = xids ? xids[i] : ntotal + i;
+ const float* xi = x + i * d;
 
- // search if there is already an entry with that id
- InvertedLists::ScopedCodes codes (invlists, list_no);
+ // search if there is already an entry with that id
+ InvertedLists::ScopedCodes codes(invlists, list_no);
 
- int64_t n = invlists->list_size (list_no);
- int64_t offset = -1;
- for (int64_t o = 0; o < n; o++) {
- if (!memcmp (codes.get() + o * code_size,
- xi, code_size)) {
- offset = o;
- break;
+ int64_t n = invlists->list_size(list_no);
+ int64_t offset = -1;
+ for (int64_t o = 0; o < n; o++) {
+ if (!memcmp(codes.get() + o * code_size, xi, code_size)) {
+ offset = o;
+ break;
+ }
  }
- }
 
- if (offset == -1) { // not found
- invlists->add_entry (list_no, id, (const uint8_t*) xi);
- } else {
- // mark equivalence
- idx_t id2 = invlists->get_single_id (list_no, offset);
- std::pair<idx_t, idx_t> pair (id2, id);
- instances.insert (pair);
- n_dup ++;
+ if (offset == -1) { // not found
+ invlists->add_entry(list_no, id, (const uint8_t*)xi);
+ } else {
+ // mark equivalence
+ idx_t id2 = invlists->get_single_id(list_no, offset);
+ std::pair<idx_t, idx_t> pair(id2, id);
+
+ #pragma omp critical
+ // executed by one thread at a time
+ instances.insert(pair);
+
+ n_dup++;
+ }
+ n_add++;
  }
- n_add++;
  }
  if (verbose) {
- printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64 " vectors"
+ printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64
+ " vectors"
  " (out of which %" PRId64 " are duplicates)\n",
- n_add, na, n_dup);
+ n_add,
+ na,
+ n_dup);
  }
  ntotal += n_add;
  }
 
- void IndexIVFFlatDedup::search_preassigned (
- idx_t n, const float *x, idx_t k,
- const idx_t *assign,
- const float *centroid_dis,
- float *distances, idx_t *labels,
- bool store_pairs,
- const IVFSearchParameters *params,
- IndexIVFStats *stats) const
- {
- FAISS_THROW_IF_NOT_MSG (
- !store_pairs, "store_pairs not supported in IVFDedup");
-
- IndexIVFFlat::search_preassigned (n, x, k, assign, centroid_dis,
- distances, labels, false,
- params);
-
- std::vector <idx_t> labels2 (k);
- std::vector <float> dis2 (k);
+ void IndexIVFFlatDedup::search_preassigned(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ const idx_t* assign,
+ const float* centroid_dis,
+ float* distances,
+ idx_t* labels,
+ bool store_pairs,
+ const IVFSearchParameters* params,
+ IndexIVFStats* stats) const {
+ FAISS_THROW_IF_NOT_MSG(
+ !store_pairs, "store_pairs not supported in IVFDedup");
+
+ IndexIVFFlat::search_preassigned(
+ n, x, k, assign, centroid_dis, distances, labels, false, params);
+
+ std::vector<idx_t> labels2(k);
+ std::vector<float> dis2(k);
 
  for (int64_t i = 0; i < n; i++) {
- idx_t *labels1 = labels + i * k;
- float *dis1 = distances + i * k;
+ idx_t* labels1 = labels + i * k;
+ float* dis1 = distances + i * k;
  int64_t j = 0;
  for (; j < k; j++) {
- if (instances.find (labels1[j]) != instances.end ()) {
+ if (instances.find(labels1[j]) != instances.end()) {
  // a duplicate: special handling
  break;
  }
@@ -344,11 +352,11 @@ void IndexIVFFlatDedup::search_preassigned (
  int64_t j0 = j;
  int64_t rp = j;
  while (j < k) {
- auto range = instances.equal_range (labels1[rp]);
+ auto range = instances.equal_range(labels1[rp]);
  float dis = dis1[rp];
  labels2[j] = labels1[rp];
  dis2[j] = dis;
- j ++;
+ j++;
  for (auto it = range.first; j < k && it != range.second; ++it) {
  labels2[j] = it->second;
  dis2[j] = dis;
@@ -356,21 +364,18 @@ void IndexIVFFlatDedup::search_preassigned (
  }
  rp++;
  }
- memcpy (labels1 + j0, labels2.data() + j0,
- sizeof(labels1[0]) * (k - j0));
- memcpy (dis1 + j0, dis2.data() + j0,
- sizeof(dis2[0]) * (k - j0));
+ memcpy(labels1 + j0,
+ labels2.data() + j0,
+ sizeof(labels1[0]) * (k - j0));
+ memcpy(dis1 + j0, dis2.data() + j0, sizeof(dis2[0]) * (k - j0));
  }
  }
-
  }
 
-
- size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
- {
+ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel) {
  std::unordered_map<idx_t, idx_t> replace;
- std::vector<std::pair<idx_t, idx_t> > toadd;
- for (auto it = instances.begin(); it != instances.end(); ) {
+ std::vector<std::pair<idx_t, idx_t>> toadd;
+ for (auto it = instances.begin(); it != instances.end();) {
  if (sel.is_member(it->first)) {
  // then we erase this entry
  if (!sel.is_member(it->second)) {
@@ -378,8 +383,8 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
  if (replace.count(it->first) == 0) {
  replace[it->first] = it->second;
  } else { // remember we should add an element
- std::pair<idx_t, idx_t> new_entry (
- replace[it->first], it->second);
+ std::pair<idx_t, idx_t> new_entry(
+ replace[it->first], it->second);
  toadd.push_back(new_entry);
  }
  }
@@ -393,32 +398,34 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
  }
  }
 
- instances.insert (toadd.begin(), toadd.end());
+ instances.insert(toadd.begin(), toadd.end());
 
  // mostly copied from IndexIVF.cpp
 
- FAISS_THROW_IF_NOT_MSG (direct_map.no(),
- "direct map remove not implemented");
+ FAISS_THROW_IF_NOT_MSG(
+ direct_map.no(), "direct map remove not implemented");
 
  std::vector<int64_t> toremove(nlist);
 
  #pragma omp parallel for
  for (int64_t i = 0; i < nlist; i++) {
- int64_t l0 = invlists->list_size (i), l = l0, j = 0;
- InvertedLists::ScopedIds idsi (invlists, i);
+ int64_t l0 = invlists->list_size(i), l = l0, j = 0;
+ InvertedLists::ScopedIds idsi(invlists, i);
  while (j < l) {
- if (sel.is_member (idsi[j])) {
+ if (sel.is_member(idsi[j])) {
  if (replace.count(idsi[j]) == 0) {
  l--;
- invlists->update_entry (
- i, j,
- invlists->get_single_id (i, l),
- InvertedLists::ScopedCodes (invlists, i, l).get());
+ invlists->update_entry(
+ i,
+ j,
+ invlists->get_single_id(i, l),
+ InvertedLists::ScopedCodes(invlists, i, l).get());
  } else {
- invlists->update_entry (
- i, j,
- replace[idsi[j]],
- InvertedLists::ScopedCodes (invlists, i, j).get());
+ invlists->update_entry(
+ i,
+ j,
+ replace[idsi[j]],
+ InvertedLists::ScopedCodes(invlists, i, j).get());
  j++;
  }
  } else {
@@ -432,37 +439,28 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
  for (int64_t i = 0; i < nlist; i++) {
  if (toremove[i] > 0) {
  nremove += toremove[i];
- invlists->resize(
- i, invlists->list_size(i) - toremove[i]);
+ invlists->resize(i, invlists->list_size(i) - toremove[i]);
  }
  }
  ntotal -= nremove;
  return nremove;
  }
 
-
  void IndexIVFFlatDedup::range_search(
- idx_t ,
- const float* ,
- float ,
- RangeSearchResult* ) const
- {
- FAISS_THROW_MSG ("not implemented");
+ idx_t,
+ const float*,
+ float,
+ RangeSearchResult*) const {
+ FAISS_THROW_MSG("not implemented");
  }
 
- void IndexIVFFlatDedup::update_vectors (int , const idx_t *, const float *)
- {
- FAISS_THROW_MSG ("not implemented");
+ void IndexIVFFlatDedup::update_vectors(int, const idx_t*, const float*) {
+ FAISS_THROW_MSG("not implemented");
  }
 
-
- void IndexIVFFlatDedup::reconstruct_from_offset (
- int64_t , int64_t , float* ) const
- {
- FAISS_THROW_MSG ("not implemented");
+ void IndexIVFFlatDedup::reconstruct_from_offset(int64_t, int64_t, float*)
+ const {
+ FAISS_THROW_MSG("not implemented");
  }
 
-
-
-
  } // namespace faiss
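Worth noting in this file: the rewritten `IndexIVFFlat::add_core` and `IndexIVFFlatDedup::add_with_ids` replace the old serial loop (and its `TODO make a omp loop with this`) with an OpenMP region in which every thread walks all input vectors but only writes to the inverted lists it owns, namely those with `list_no % nt == rank`. Each list therefore has exactly one writer and needs no lock; only the shared `instances` multimap in the dedup path still requires an `omp critical`. A stripped-down sketch of that ownership pattern, with illustrative names that are not from the diff:

```cpp
#include <omp.h>

#include <cstdio>
#include <vector>

// Toy version of the list-ownership rule used above: every thread scans
// all items but appends only to the lists it owns, so each list has a
// single writer and no mutex is required on the lists themselves.
int main() {
    const int nlist = 8;
    const std::vector<int> assignment = {3, 7, 3, 0, 5, 7, 1, 3, 0, 5};
    std::vector<std::vector<int>> lists(nlist);
    long long n_add = 0;

#pragma omp parallel reduction(+ : n_add)
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();
        for (size_t i = 0; i < assignment.size(); i++) {
            int list_no = assignment[i];
            if (list_no % nt == rank) { // this thread owns the list
                lists[list_no].push_back((int)i);
                n_add++;
            }
        }
    }
    printf("added %lld entries\n", n_add);
    return 0;
}
```

The trade-off of this scheme is that every thread reads the full input once, but for IVF adds the per-vector work (copying codes into lists) dominates, so avoiding per-list locks is usually the better deal.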