faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -10,30 +10,27 @@
10
10
  #ifndef FAISS_INDEX_IVF_H
11
11
  #define FAISS_INDEX_IVF_H
12
12
 
13
-
14
- #include <vector>
15
- #include <unordered_map>
16
13
  #include <stdint.h>
14
+ #include <unordered_map>
15
+ #include <vector>
17
16
 
18
- #include <faiss/Index.h>
19
- #include <faiss/invlists/InvertedLists.h>
20
- #include <faiss/invlists/DirectMap.h>
21
17
  #include <faiss/Clustering.h>
18
+ #include <faiss/Index.h>
22
19
  #include <faiss/impl/platform_macros.h>
20
+ #include <faiss/invlists/DirectMap.h>
21
+ #include <faiss/invlists/InvertedLists.h>
23
22
  #include <faiss/utils/Heap.h>
24
23
 
25
-
26
24
  namespace faiss {
27
25
 
28
-
29
26
  /** Encapsulates a quantizer object for the IndexIVF
30
27
  *
31
28
  * The class isolates the fields that are independent of the storage
32
29
  * of the lists (especially training)
33
30
  */
34
31
  struct Level1Quantizer {
35
- Index * quantizer; ///< quantizer that maps vectors to inverted lists
36
- size_t nlist; ///< number of possible key values
32
+ Index* quantizer; ///< quantizer that maps vectors to inverted lists
33
+ size_t nlist; ///< number of possible key values
37
34
 
38
35
  /**
39
36
  * = 0: use the quantizer as index in a kmeans training
@@ -41,40 +38,37 @@ struct Level1Quantizer {
41
38
  * = 2: kmeans training on a flat index + add the centroids to the quantizer
42
39
  */
43
40
  char quantizer_trains_alone;
44
- bool own_fields; ///< whether object owns the quantizer
41
+ bool own_fields; ///< whether object owns the quantizer
45
42
 
46
43
  ClusteringParameters cp; ///< to override default clustering params
47
- Index *clustering_index; ///< to override index used during clustering
44
+ Index* clustering_index; ///< to override index used during clustering
48
45
 
49
46
  /// Trains the quantizer and calls train_residual to train sub-quantizers
50
- void train_q1 (size_t n, const float *x, bool verbose,
51
- MetricType metric_type);
52
-
47
+ void train_q1(
48
+ size_t n,
49
+ const float* x,
50
+ bool verbose,
51
+ MetricType metric_type);
53
52
 
54
53
  /// compute the number of bytes required to store list ids
55
- size_t coarse_code_size () const;
56
- void encode_listno (Index::idx_t list_no, uint8_t *code) const;
57
- Index::idx_t decode_listno (const uint8_t *code) const;
54
+ size_t coarse_code_size() const;
55
+ void encode_listno(Index::idx_t list_no, uint8_t* code) const;
56
+ Index::idx_t decode_listno(const uint8_t* code) const;
58
57
 
59
- Level1Quantizer (Index * quantizer, size_t nlist);
58
+ Level1Quantizer(Index* quantizer, size_t nlist);
60
59
 
61
- Level1Quantizer ();
62
-
63
- ~Level1Quantizer ();
60
+ Level1Quantizer();
64
61
 
62
+ ~Level1Quantizer();
65
63
  };
66
64
 
67
-
68
-
69
65
  struct IVFSearchParameters {
70
- size_t nprobe; ///< number of probes at query time
71
- size_t max_codes; ///< max nb of codes to visit to do a query
72
- IVFSearchParameters(): nprobe(1), max_codes(0) {}
73
- virtual ~IVFSearchParameters () {}
66
+ size_t nprobe; ///< number of probes at query time
67
+ size_t max_codes; ///< max nb of codes to visit to do a query
68
+ IVFSearchParameters() : nprobe(1), max_codes(0) {}
69
+ virtual ~IVFSearchParameters() {}
74
70
  };
75
71
 
76
-
77
-
78
72
  struct InvertedListScanner;
79
73
  struct IndexIVFStats;
80
74
 
@@ -98,15 +92,15 @@ struct IndexIVFStats;
98
92
  * Sub-classes implement a post-filtering of the index that refines
99
93
  * the distance estimation from the query to database vectors.
100
94
  */
101
- struct IndexIVF: Index, Level1Quantizer {
95
+ struct IndexIVF : Index, Level1Quantizer {
102
96
  /// Access to the actual data
103
- InvertedLists *invlists;
97
+ InvertedLists* invlists;
104
98
  bool own_invlists;
105
99
 
106
- size_t code_size; ///< code size per vector in bytes
100
+ size_t code_size; ///< code size per vector in bytes
107
101
 
108
- size_t nprobe; ///< number of probes at query time
109
- size_t max_codes; ///< max nb of codes to visit to do a query
102
+ size_t nprobe; ///< number of probes at query time
103
+ size_t max_codes; ///< max nb of codes to visit to do a query
110
104
 
111
105
  /** Parallel mode determines how queries are parallelized with OpenMP
112
106
  *
@@ -130,9 +124,12 @@ struct IndexIVF: Index, Level1Quantizer {
130
124
  * identifier. The pointer is borrowed: the quantizer should not
131
125
  * be deleted while the IndexIVF is in use.
132
126
  */
133
- IndexIVF (Index * quantizer, size_t d,
134
- size_t nlist, size_t code_size,
135
- MetricType metric = METRIC_L2);
127
+ IndexIVF(
128
+ Index* quantizer,
129
+ size_t d,
130
+ size_t nlist,
131
+ size_t code_size,
132
+ MetricType metric = METRIC_L2);
136
133
 
137
134
  void reset() override;
138
135
 
@@ -145,6 +142,19 @@ struct IndexIVF: Index, Level1Quantizer {
145
142
  /// default implementation that calls encode_vectors
146
143
  void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
147
144
 
145
+ /** Implementation of vector addition where the vector assignments are
146
+ * predefined. The default implementation hands over the code extraction to
147
+ * encode_vectors.
148
+ *
149
+ * @param precomputed_idx quantization indices for the input vectors
150
+ * (size n)
151
+ */
152
+ virtual void add_core(
153
+ idx_t n,
154
+ const float* x,
155
+ const idx_t* xids,
156
+ const idx_t* precomputed_idx);
157
+
148
158
  /** Encodes a set of vectors as they would appear in the inverted lists
149
159
  *
150
160
  * @param list_nos inverted list ids as returned by the
@@ -154,14 +164,16 @@ struct IndexIVF: Index, Level1Quantizer {
154
164
  * include the list ids in the code (in this case add
155
165
  * ceil(log8(nlist)) to the code size)
156
166
  */
157
- virtual void encode_vectors(idx_t n, const float* x,
158
- const idx_t *list_nos,
159
- uint8_t * codes,
160
- bool include_listno = false) const = 0;
167
+ virtual void encode_vectors(
168
+ idx_t n,
169
+ const float* x,
170
+ const idx_t* list_nos,
171
+ uint8_t* codes,
172
+ bool include_listno = false) const = 0;
161
173
 
162
174
  /// Sub-classes that encode the residuals can train their encoders here
163
175
  /// does nothing by default
164
- virtual void train_residual (idx_t n, const float *x);
176
+ virtual void train_residual(idx_t n, const float* x);
165
177
 
166
178
  /** search a set of vectors, that are pre-quantized by the IVF
167
179
  * quantizer. Fill in the corresponding heaps with the query
@@ -182,36 +194,50 @@ struct IndexIVF: Index, Level1Quantizer {
182
194
  * @param params used to override the object's search parameters
183
195
  * @param stats search stats to be updated (can be null)
184
196
  */
185
- virtual void search_preassigned (
186
- idx_t n, const float *x, idx_t k,
187
- const idx_t *assign, const float *centroid_dis,
188
- float *distances, idx_t *labels,
197
+ virtual void search_preassigned(
198
+ idx_t n,
199
+ const float* x,
200
+ idx_t k,
201
+ const idx_t* assign,
202
+ const float* centroid_dis,
203
+ float* distances,
204
+ idx_t* labels,
189
205
  bool store_pairs,
190
- const IVFSearchParameters *params=nullptr,
191
- IndexIVFStats *stats=nullptr
192
- ) const;
206
+ const IVFSearchParameters* params = nullptr,
207
+ IndexIVFStats* stats = nullptr) const;
193
208
 
194
209
  /** assign the vectors, then call search_preassigned */
195
- void search (idx_t n, const float *x, idx_t k,
196
- float *distances, idx_t *labels) const override;
197
-
198
- void range_search (idx_t n, const float* x, float radius,
199
- RangeSearchResult* result) const override;
210
+ void search(
211
+ idx_t n,
212
+ const float* x,
213
+ idx_t k,
214
+ float* distances,
215
+ idx_t* labels) const override;
216
+
217
+ void range_search(
218
+ idx_t n,
219
+ const float* x,
220
+ float radius,
221
+ RangeSearchResult* result) const override;
200
222
 
201
223
  void range_search_preassigned(
202
- idx_t nx, const float *x, float radius,
203
- const idx_t *keys, const float *coarse_dis,
204
- RangeSearchResult *result,
205
- bool store_pairs=false,
206
- const IVFSearchParameters *params=nullptr,
207
- IndexIVFStats *stats=nullptr) const;
224
+ idx_t nx,
225
+ const float* x,
226
+ float radius,
227
+ const idx_t* keys,
228
+ const float* coarse_dis,
229
+ RangeSearchResult* result,
230
+ bool store_pairs = false,
231
+ const IVFSearchParameters* params = nullptr,
232
+ IndexIVFStats* stats = nullptr) const;
208
233
 
209
234
  /// get a scanner for this index (store_pairs means ignore labels)
210
- virtual InvertedListScanner *get_InvertedListScanner (
211
- bool store_pairs=false) const;
235
+ virtual InvertedListScanner* get_InvertedListScanner(
236
+ bool store_pairs = false) const;
212
237
 
213
- /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2 */
214
- void reconstruct (idx_t key, float* recons) const override;
238
+ /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
239
+ */
240
+ void reconstruct(idx_t key, float* recons) const override;
215
241
 
216
242
  /** Update a subset of vectors.
217
243
  *
@@ -221,7 +247,7 @@ struct IndexIVF: Index, Level1Quantizer {
221
247
  * @param idx vector indices to update, size nv
222
248
  * @param v vectors of new values, size nv*d
223
249
  */
224
- virtual void update_vectors (int nv, const idx_t *idx, const float *v);
250
+ virtual void update_vectors(int nv, const idx_t* idx, const float* v);
225
251
 
226
252
  /** Reconstruct a subset of the indexed vectors.
227
253
  *
@@ -243,9 +269,13 @@ struct IndexIVF: Index, Level1Quantizer {
243
269
  *
244
270
  * @param recons reconstructed vectors size (n, k, d)
245
271
  */
246
- void search_and_reconstruct (idx_t n, const float *x, idx_t k,
247
- float *distances, idx_t *labels,
248
- float *recons) const override;
272
+ void search_and_reconstruct(
273
+ idx_t n,
274
+ const float* x,
275
+ idx_t k,
276
+ float* distances,
277
+ idx_t* labels,
278
+ float* recons) const override;
249
279
 
250
280
  /** Reconstruct a vector given the location in terms of (inv list index +
251
281
  * inv list offset) instead of the id.
@@ -254,9 +284,10 @@ struct IndexIVF: Index, Level1Quantizer {
254
284
  * the inv list offset is computed by search_preassigned() with
255
285
  * `store_pairs` set.
256
286
  */
257
- virtual void reconstruct_from_offset (int64_t list_no, int64_t offset,
258
- float* recons) const;
259
-
287
+ virtual void reconstruct_from_offset(
288
+ int64_t list_no,
289
+ int64_t offset,
290
+ float* recons) const;
260
291
 
261
292
  /// Dataset manipulation functions
262
293
 
@@ -265,12 +296,12 @@ struct IndexIVF: Index, Level1Quantizer {
265
296
  /** check that the two indexes are compatible (ie, they are
266
297
  * trained in the same way and have the same
267
298
  * parameters). Otherwise throw. */
268
- void check_compatible_for_merge (const IndexIVF &other) const;
299
+ void check_compatible_for_merge(const IndexIVF& other) const;
269
300
 
270
301
  /** moves the entries from another dataset to self. On output,
271
302
  * other is empty. add_id is added to all moved ids (for
272
303
  * sequential ids, this would be this->ntotal */
273
- virtual void merge_from (IndexIVF &other, idx_t add_id);
304
+ virtual void merge_from(IndexIVF& other, idx_t add_id);
274
305
 
275
306
  /** copy a subset of the entries index to the other index
276
307
  *
@@ -279,34 +310,36 @@ struct IndexIVF: Index, Level1Quantizer {
279
310
  * if subset_type == 2: copies inverted lists such that a1
280
311
  * elements are left before and a2 elements are after
281
312
  */
282
- virtual void copy_subset_to (IndexIVF & other, int subset_type,
283
- idx_t a1, idx_t a2) const;
313
+ virtual void copy_subset_to(
314
+ IndexIVF& other,
315
+ int subset_type,
316
+ idx_t a1,
317
+ idx_t a2) const;
284
318
 
285
319
  ~IndexIVF() override;
286
320
 
287
- size_t get_list_size (size_t list_no) const
288
- { return invlists->list_size(list_no); }
321
+ size_t get_list_size(size_t list_no) const {
322
+ return invlists->list_size(list_no);
323
+ }
289
324
 
290
325
  /** initialize a direct map
291
326
  *
292
327
  * @param new_maintain_direct_map if true, create a direct map,
293
328
  * else clear it
294
329
  */
295
- void make_direct_map (bool new_maintain_direct_map=true);
296
-
297
- void set_direct_map_type (DirectMap::Type type);
330
+ void make_direct_map(bool new_maintain_direct_map = true);
298
331
 
332
+ void set_direct_map_type(DirectMap::Type type);
299
333
 
300
334
  /// replace the inverted lists, old one is deallocated if own_invlists
301
- void replace_invlists (InvertedLists *il, bool own=false);
335
+ void replace_invlists(InvertedLists* il, bool own = false);
302
336
 
303
337
  /* The standalone codec interface (except sa_decode that is specific) */
304
- size_t sa_code_size () const override;
338
+ size_t sa_code_size() const override;
305
339
 
306
- void sa_encode (idx_t n, const float *x,
307
- uint8_t *bytes) const override;
340
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
308
341
 
309
- IndexIVF ();
342
+ IndexIVF();
310
343
  };
311
344
 
312
345
  struct RangeQueryResult;
@@ -316,17 +349,16 @@ struct RangeQueryResult;
316
349
  * distance_to_code and scan_codes can be called in multiple
317
350
  * threads */
318
351
  struct InvertedListScanner {
319
-
320
352
  using idx_t = Index::idx_t;
321
353
 
322
354
  /// from now on we handle this query.
323
- virtual void set_query (const float *query_vector) = 0;
355
+ virtual void set_query(const float* query_vector) = 0;
324
356
 
325
357
  /// following codes come from this inverted list
326
- virtual void set_list (idx_t list_no, float coarse_dis) = 0;
358
+ virtual void set_list(idx_t list_no, float coarse_dis) = 0;
327
359
 
328
360
  /// compute a single query-to-code distance
329
- virtual float distance_to_code (const uint8_t *code) const = 0;
361
+ virtual float distance_to_code(const uint8_t* code) const = 0;
330
362
 
331
363
  /** scan a set of codes, compute distances to current query and
332
364
  * update heap of results if necessary.
@@ -339,45 +371,46 @@ struct InvertedListScanner {
339
371
  * @param k heap size
340
372
  * @return number of heap updates performed
341
373
  */
342
- virtual size_t scan_codes (size_t n,
343
- const uint8_t *codes,
344
- const idx_t *ids,
345
- float *distances, idx_t *labels,
346
- size_t k) const = 0;
374
+ virtual size_t scan_codes(
375
+ size_t n,
376
+ const uint8_t* codes,
377
+ const idx_t* ids,
378
+ float* distances,
379
+ idx_t* labels,
380
+ size_t k) const = 0;
347
381
 
348
382
  /** scan a set of codes, compute distances to current query and
349
383
  * update results if distances are below radius
350
384
  *
351
385
  * (default implementation fails) */
352
- virtual void scan_codes_range (size_t n,
353
- const uint8_t *codes,
354
- const idx_t *ids,
355
- float radius,
356
- RangeQueryResult &result) const;
357
-
358
- virtual ~InvertedListScanner () {}
359
-
386
+ virtual void scan_codes_range(
387
+ size_t n,
388
+ const uint8_t* codes,
389
+ const idx_t* ids,
390
+ float radius,
391
+ RangeQueryResult& result) const;
392
+
393
+ virtual ~InvertedListScanner() {}
360
394
  };
361
395
 
362
-
363
396
  struct IndexIVFStats {
364
- size_t nq; // nb of queries run
365
- size_t nlist; // nb of inverted lists scanned
366
- size_t ndis; // nb of distancs computed
367
- size_t nheap_updates; // nb of times the heap was updated
397
+ size_t nq; // nb of queries run
398
+ size_t nlist; // nb of inverted lists scanned
399
+ size_t ndis; // nb of distancs computed
400
+ size_t nheap_updates; // nb of times the heap was updated
368
401
  double quantization_time; // time spent quantizing vectors (in ms)
369
402
  double search_time; // time spent searching lists (in ms)
370
403
 
371
- IndexIVFStats () {reset (); }
372
- void reset ();
373
- void add (const IndexIVFStats & other);
404
+ IndexIVFStats() {
405
+ reset();
406
+ }
407
+ void reset();
408
+ void add(const IndexIVFStats& other);
374
409
  };
375
410
 
376
411
  // global var that collects them all
377
412
  FAISS_API extern IndexIVFStats indexIVF_stats;
378
413
 
379
-
380
414
  } // namespace faiss
381
415
 
382
-
383
416
  #endif
@@ -9,332 +9,340 @@
9
9
 
10
10
  #include <faiss/IndexIVFFlat.h>
11
11
 
12
+ #include <omp.h>
13
+
12
14
  #include <cinttypes>
13
15
  #include <cstdio>
14
16
 
15
17
  #include <faiss/IndexFlat.h>
16
18
 
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
17
21
  #include <faiss/utils/distances.h>
18
22
  #include <faiss/utils/utils.h>
19
- #include <faiss/impl/FaissAssert.h>
20
- #include <faiss/impl/AuxIndexStructures.h>
21
-
22
23
 
23
24
  namespace faiss {
24
25
 
25
-
26
26
  /*****************************************
27
27
  * IndexIVFFlat implementation
28
28
  ******************************************/
29
29
 
30
- IndexIVFFlat::IndexIVFFlat (Index * quantizer,
31
- size_t d, size_t nlist, MetricType metric):
32
- IndexIVF (quantizer, d, nlist, sizeof(float) * d, metric)
33
- {
30
+ IndexIVFFlat::IndexIVFFlat(
31
+ Index* quantizer,
32
+ size_t d,
33
+ size_t nlist,
34
+ MetricType metric)
35
+ : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
34
36
  code_size = sizeof(float) * d;
35
37
  }
36
38
 
39
+ void IndexIVFFlat::add_core(
40
+ idx_t n,
41
+ const float* x,
42
+ const int64_t* xids,
43
+ const int64_t* coarse_idx)
37
44
 
38
- void IndexIVFFlat::add_with_ids (idx_t n, const float * x, const idx_t *xids)
39
45
  {
40
- add_core (n, x, xids, nullptr);
41
- }
42
-
43
- void IndexIVFFlat::add_core (idx_t n, const float * x, const int64_t *xids,
44
- const int64_t *precomputed_idx)
46
+ FAISS_THROW_IF_NOT(is_trained);
47
+ FAISS_THROW_IF_NOT(coarse_idx);
48
+ assert(invlists);
49
+ direct_map.check_can_add(xids);
45
50
 
46
- {
47
- FAISS_THROW_IF_NOT (is_trained);
48
- assert (invlists);
49
- direct_map.check_can_add (xids);
50
- const int64_t * idx;
51
- ScopeDeleter<int64_t> del;
52
-
53
- if (precomputed_idx) {
54
- idx = precomputed_idx;
55
- } else {
56
- int64_t * idx0 = new int64_t [n];
57
- del.set (idx0);
58
- quantizer->assign (n, x, idx0);
59
- idx = idx0;
60
- }
61
51
  int64_t n_add = 0;
62
- for (size_t i = 0; i < n; i++) {
63
- idx_t id = xids ? xids[i] : ntotal + i;
64
- idx_t list_no = idx [i];
65
- size_t offset;
66
-
67
- if (list_no >= 0) {
68
- const float *xi = x + i * d;
69
- offset = invlists->add_entry (
70
- list_no, id, (const uint8_t*) xi);
71
- n_add++;
72
- } else {
73
- offset = 0;
52
+
53
+ DirectMapAdd dm_adder(direct_map, n, xids);
54
+
55
+ #pragma omp parallel reduction(+ : n_add)
56
+ {
57
+ int nt = omp_get_num_threads();
58
+ int rank = omp_get_thread_num();
59
+
60
+ // each thread takes care of a subset of lists
61
+ for (size_t i = 0; i < n; i++) {
62
+ idx_t list_no = coarse_idx[i];
63
+
64
+ if (list_no >= 0 && list_no % nt == rank) {
65
+ idx_t id = xids ? xids[i] : ntotal + i;
66
+ const float* xi = x + i * d;
67
+ size_t offset =
68
+ invlists->add_entry(list_no, id, (const uint8_t*)xi);
69
+ dm_adder.add(i, list_no, offset);
70
+ n_add++;
71
+ } else if (rank == 0 && list_no == -1) {
72
+ dm_adder.add(i, -1, 0);
73
+ }
74
74
  }
75
- direct_map.add_single_id (id, list_no, offset);
76
75
  }
77
76
 
78
77
  if (verbose) {
79
- printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64 " vectors\n",
80
- n_add, n);
78
+ printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64
79
+ " vectors\n",
80
+ n_add,
81
+ n);
81
82
  }
82
83
  ntotal += n;
83
84
  }
84
85
 
85
- void IndexIVFFlat::encode_vectors(idx_t n, const float* x,
86
- const idx_t * list_nos,
87
- uint8_t * codes,
88
- bool include_listnos) const
89
- {
86
+ void IndexIVFFlat::encode_vectors(
87
+ idx_t n,
88
+ const float* x,
89
+ const idx_t* list_nos,
90
+ uint8_t* codes,
91
+ bool include_listnos) const {
90
92
  if (!include_listnos) {
91
- memcpy (codes, x, code_size * n);
93
+ memcpy(codes, x, code_size * n);
92
94
  } else {
93
- size_t coarse_size = coarse_code_size ();
95
+ size_t coarse_size = coarse_code_size();
94
96
  for (size_t i = 0; i < n; i++) {
95
- int64_t list_no = list_nos [i];
96
- uint8_t *code = codes + i * (code_size + coarse_size);
97
- const float *xi = x + i * d;
97
+ int64_t list_no = list_nos[i];
98
+ uint8_t* code = codes + i * (code_size + coarse_size);
99
+ const float* xi = x + i * d;
98
100
  if (list_no >= 0) {
99
- encode_listno (list_no, code);
100
- memcpy (code + coarse_size, xi, code_size);
101
+ encode_listno(list_no, code);
102
+ memcpy(code + coarse_size, xi, code_size);
101
103
  } else {
102
- memset (code, 0, code_size + coarse_size);
104
+ memset(code, 0, code_size + coarse_size);
103
105
  }
104
-
105
106
  }
106
107
  }
107
108
  }
108
109
 
109
- void IndexIVFFlat::sa_decode (idx_t n, const uint8_t *bytes,
110
- float *x) const
111
- {
112
- size_t coarse_size = coarse_code_size ();
110
+ void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
111
+ size_t coarse_size = coarse_code_size();
113
112
  for (size_t i = 0; i < n; i++) {
114
- const uint8_t *code = bytes + i * (code_size + coarse_size);
115
- float *xi = x + i * d;
116
- memcpy (xi, code + coarse_size, code_size);
113
+ const uint8_t* code = bytes + i * (code_size + coarse_size);
114
+ float* xi = x + i * d;
115
+ memcpy(xi, code + coarse_size, code_size);
117
116
  }
118
117
  }
119
118
 
120
-
121
119
  namespace {
122
120
 
123
-
124
- template<MetricType metric, class C>
125
- struct IVFFlatScanner: InvertedListScanner {
121
+ template <MetricType metric, class C>
122
+ struct IVFFlatScanner : InvertedListScanner {
126
123
  size_t d;
127
124
  bool store_pairs;
128
125
 
129
- IVFFlatScanner(size_t d, bool store_pairs):
130
- d(d), store_pairs(store_pairs) {}
126
+ IVFFlatScanner(size_t d, bool store_pairs)
127
+ : d(d), store_pairs(store_pairs) {}
131
128
 
132
- const float *xi;
133
- void set_query (const float *query) override {
129
+ const float* xi;
130
+ void set_query(const float* query) override {
134
131
  this->xi = query;
135
132
  }
136
133
 
137
134
  idx_t list_no;
138
- void set_list (idx_t list_no, float /* coarse_dis */) override {
135
+ void set_list(idx_t list_no, float /* coarse_dis */) override {
139
136
  this->list_no = list_no;
140
137
  }
141
138
 
142
- float distance_to_code (const uint8_t *code) const override {
143
- const float *yj = (float*)code;
144
- float dis = metric == METRIC_INNER_PRODUCT ?
145
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
139
+ float distance_to_code(const uint8_t* code) const override {
140
+ const float* yj = (float*)code;
141
+ float dis = metric == METRIC_INNER_PRODUCT
142
+ ? fvec_inner_product(xi, yj, d)
143
+ : fvec_L2sqr(xi, yj, d);
146
144
  return dis;
147
145
  }
148
146
 
149
- size_t scan_codes (size_t list_size,
150
- const uint8_t *codes,
151
- const idx_t *ids,
152
- float *simi, idx_t *idxi,
153
- size_t k) const override
154
- {
155
- const float *list_vecs = (const float*)codes;
147
+ size_t scan_codes(
148
+ size_t list_size,
149
+ const uint8_t* codes,
150
+ const idx_t* ids,
151
+ float* simi,
152
+ idx_t* idxi,
153
+ size_t k) const override {
154
+ const float* list_vecs = (const float*)codes;
156
155
  size_t nup = 0;
157
156
  for (size_t j = 0; j < list_size; j++) {
158
- const float * yj = list_vecs + d * j;
159
- float dis = metric == METRIC_INNER_PRODUCT ?
160
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
161
- if (C::cmp (simi[0], dis)) {
162
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
163
- heap_replace_top<C> (k, simi, idxi, dis, id);
157
+ const float* yj = list_vecs + d * j;
158
+ float dis = metric == METRIC_INNER_PRODUCT
159
+ ? fvec_inner_product(xi, yj, d)
160
+ : fvec_L2sqr(xi, yj, d);
161
+ if (C::cmp(simi[0], dis)) {
162
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
163
+ heap_replace_top<C>(k, simi, idxi, dis, id);
164
164
  nup++;
165
165
  }
166
166
  }
167
167
  return nup;
168
168
  }
169
169
 
170
- void scan_codes_range (size_t list_size,
171
- const uint8_t *codes,
172
- const idx_t *ids,
173
- float radius,
174
- RangeQueryResult & res) const override
175
- {
176
- const float *list_vecs = (const float*)codes;
170
+ void scan_codes_range(
171
+ size_t list_size,
172
+ const uint8_t* codes,
173
+ const idx_t* ids,
174
+ float radius,
175
+ RangeQueryResult& res) const override {
176
+ const float* list_vecs = (const float*)codes;
177
177
  for (size_t j = 0; j < list_size; j++) {
178
- const float * yj = list_vecs + d * j;
179
- float dis = metric == METRIC_INNER_PRODUCT ?
180
- fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
181
- if (C::cmp (radius, dis)) {
182
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
183
- res.add (dis, id);
178
+ const float* yj = list_vecs + d * j;
179
+ float dis = metric == METRIC_INNER_PRODUCT
180
+ ? fvec_inner_product(xi, yj, d)
181
+ : fvec_L2sqr(xi, yj, d);
182
+ if (C::cmp(radius, dis)) {
183
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
184
+ res.add(dis, id);
184
185
  }
185
186
  }
186
187
  }
187
-
188
-
189
188
  };
190
189
 
191
-
192
190
  } // anonymous namespace
193
191
 
194
-
195
-
196
- InvertedListScanner* IndexIVFFlat::get_InvertedListScanner
197
- (bool store_pairs) const
198
- {
192
+ InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
193
+ bool store_pairs) const {
199
194
  if (metric_type == METRIC_INNER_PRODUCT) {
200
- return new IVFFlatScanner<
201
- METRIC_INNER_PRODUCT, CMin<float, int64_t> > (d, store_pairs);
195
+ return new IVFFlatScanner<METRIC_INNER_PRODUCT, CMin<float, int64_t>>(
196
+ d, store_pairs);
202
197
  } else if (metric_type == METRIC_L2) {
203
- return new IVFFlatScanner<
204
- METRIC_L2, CMax<float, int64_t> >(d, store_pairs);
198
+ return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>>(
199
+ d, store_pairs);
205
200
  } else {
206
201
  FAISS_THROW_MSG("metric type not supported");
207
202
  }
208
203
  return nullptr;
209
204
  }
210
205
 
211
-
212
-
213
-
214
- void IndexIVFFlat::reconstruct_from_offset (int64_t list_no, int64_t offset,
215
- float* recons) const
216
- {
217
- memcpy (recons, invlists->get_single_code (list_no, offset), code_size);
206
+ void IndexIVFFlat::reconstruct_from_offset(
207
+ int64_t list_no,
208
+ int64_t offset,
209
+ float* recons) const {
210
+ memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
218
211
  }
219
212
 
220
213
  /*****************************************
221
214
  * IndexIVFFlatDedup implementation
222
215
  ******************************************/
223
216
 
224
- IndexIVFFlatDedup::IndexIVFFlatDedup (
225
- Index * quantizer, size_t d, size_t nlist_,
226
- MetricType metric_type):
227
- IndexIVFFlat (quantizer, d, nlist_, metric_type)
228
- {}
229
-
217
+ IndexIVFFlatDedup::IndexIVFFlatDedup(
218
+ Index* quantizer,
219
+ size_t d,
220
+ size_t nlist_,
221
+ MetricType metric_type)
222
+ : IndexIVFFlat(quantizer, d, nlist_, metric_type) {}
230
223
 
231
- void IndexIVFFlatDedup::train(idx_t n, const float* x)
232
- {
224
+ void IndexIVFFlatDedup::train(idx_t n, const float* x) {
233
225
  std::unordered_map<uint64_t, idx_t> map;
234
- float * x2 = new float [n * d];
235
- ScopeDeleter<float> del (x2);
226
+ float* x2 = new float[n * d];
227
+ ScopeDeleter<float> del(x2);
236
228
 
237
229
  int64_t n2 = 0;
238
230
  for (int64_t i = 0; i < n; i++) {
239
- uint64_t hash = hash_bytes((uint8_t *)(x + i * d), code_size);
231
+ uint64_t hash = hash_bytes((uint8_t*)(x + i * d), code_size);
240
232
  if (map.count(hash) &&
241
- !memcmp (x2 + map[hash] * d, x + i * d, code_size)) {
233
+ !memcmp(x2 + map[hash] * d, x + i * d, code_size)) {
242
234
  // is duplicate, skip
243
235
  } else {
244
- map [hash] = n2;
245
- memcpy (x2 + n2 * d, x + i * d, code_size);
246
- n2 ++;
236
+ map[hash] = n2;
237
+ memcpy(x2 + n2 * d, x + i * d, code_size);
238
+ n2++;
247
239
  }
248
240
  }
249
241
  if (verbose) {
250
- printf ("IndexIVFFlatDedup::train: train on %" PRId64 " points after dedup "
251
- "(was %" PRId64 " points)\n", n2, n);
242
+ printf("IndexIVFFlatDedup::train: train on %" PRId64
243
+ " points after dedup "
244
+ "(was %" PRId64 " points)\n",
245
+ n2,
246
+ n);
252
247
  }
253
- IndexIVFFlat::train (n2, x2);
248
+ IndexIVFFlat::train(n2, x2);
254
249
  }
255
250
 
251
+ void IndexIVFFlatDedup::add_with_ids(
252
+ idx_t na,
253
+ const float* x,
254
+ const idx_t* xids) {
255
+ FAISS_THROW_IF_NOT(is_trained);
256
+ assert(invlists);
257
+ FAISS_THROW_IF_NOT_MSG(
258
+ direct_map.no(), "IVFFlatDedup not implemented with direct_map");
259
+ int64_t* idx = new int64_t[na];
260
+ ScopeDeleter<int64_t> del(idx);
261
+ quantizer->assign(na, x, idx);
256
262
 
263
+ int64_t n_add = 0, n_dup = 0;
257
264
 
258
- void IndexIVFFlatDedup::add_with_ids(
259
- idx_t na, const float* x, const idx_t* xids)
260
- {
265
+ #pragma omp parallel reduction(+ : n_add, n_dup)
266
+ {
267
+ int nt = omp_get_num_threads();
268
+ int rank = omp_get_thread_num();
261
269
 
262
- FAISS_THROW_IF_NOT (is_trained);
263
- assert (invlists);
264
- FAISS_THROW_IF_NOT_MSG (direct_map.no(),
265
- "IVFFlatDedup not implemented with direct_map");
266
- int64_t * idx = new int64_t [na];
267
- ScopeDeleter<int64_t> del (idx);
268
- quantizer->assign (na, x, idx);
270
+ // each thread takes care of a subset of lists
271
+ for (size_t i = 0; i < na; i++) {
272
+ int64_t list_no = idx[i];
269
273
 
270
- int64_t n_add = 0, n_dup = 0;
271
- // TODO make a omp loop with this
272
- for (size_t i = 0; i < na; i++) {
273
- idx_t id = xids ? xids[i] : ntotal + i;
274
- int64_t list_no = idx [i];
274
+ if (list_no < 0 || list_no % nt != rank) {
275
+ continue;
276
+ }
275
277
 
276
- if (list_no < 0) {
277
- continue;
278
- }
279
- const float *xi = x + i * d;
278
+ idx_t id = xids ? xids[i] : ntotal + i;
279
+ const float* xi = x + i * d;
280
280
 
281
- // search if there is already an entry with that id
282
- InvertedLists::ScopedCodes codes (invlists, list_no);
281
+ // search if there is already an entry with that id
282
+ InvertedLists::ScopedCodes codes(invlists, list_no);
283
283
 
284
- int64_t n = invlists->list_size (list_no);
285
- int64_t offset = -1;
286
- for (int64_t o = 0; o < n; o++) {
287
- if (!memcmp (codes.get() + o * code_size,
288
- xi, code_size)) {
289
- offset = o;
290
- break;
284
+ int64_t n = invlists->list_size(list_no);
285
+ int64_t offset = -1;
286
+ for (int64_t o = 0; o < n; o++) {
287
+ if (!memcmp(codes.get() + o * code_size, xi, code_size)) {
288
+ offset = o;
289
+ break;
290
+ }
291
291
  }
292
- }
293
292
 
294
- if (offset == -1) { // not found
295
- invlists->add_entry (list_no, id, (const uint8_t*) xi);
296
- } else {
297
- // mark equivalence
298
- idx_t id2 = invlists->get_single_id (list_no, offset);
299
- std::pair<idx_t, idx_t> pair (id2, id);
300
- instances.insert (pair);
301
- n_dup ++;
293
+ if (offset == -1) { // not found
294
+ invlists->add_entry(list_no, id, (const uint8_t*)xi);
295
+ } else {
296
+ // mark equivalence
297
+ idx_t id2 = invlists->get_single_id(list_no, offset);
298
+ std::pair<idx_t, idx_t> pair(id2, id);
299
+
300
+ #pragma omp critical
301
+ // executed by one thread at a time
302
+ instances.insert(pair);
303
+
304
+ n_dup++;
305
+ }
306
+ n_add++;
302
307
  }
303
- n_add++;
304
308
  }
305
309
  if (verbose) {
306
- printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64 " vectors"
310
+ printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64
311
+ " vectors"
307
312
  " (out of which %" PRId64 " are duplicates)\n",
308
- n_add, na, n_dup);
313
+ n_add,
314
+ na,
315
+ n_dup);
309
316
  }
310
317
  ntotal += n_add;
311
318
  }
312
319
 
313
- void IndexIVFFlatDedup::search_preassigned (
314
- idx_t n, const float *x, idx_t k,
315
- const idx_t *assign,
316
- const float *centroid_dis,
317
- float *distances, idx_t *labels,
318
- bool store_pairs,
319
- const IVFSearchParameters *params,
320
- IndexIVFStats *stats) const
321
- {
322
- FAISS_THROW_IF_NOT_MSG (
323
- !store_pairs, "store_pairs not supported in IVFDedup");
324
-
325
- IndexIVFFlat::search_preassigned (n, x, k, assign, centroid_dis,
326
- distances, labels, false,
327
- params);
328
-
329
- std::vector <idx_t> labels2 (k);
330
- std::vector <float> dis2 (k);
320
+ void IndexIVFFlatDedup::search_preassigned(
321
+ idx_t n,
322
+ const float* x,
323
+ idx_t k,
324
+ const idx_t* assign,
325
+ const float* centroid_dis,
326
+ float* distances,
327
+ idx_t* labels,
328
+ bool store_pairs,
329
+ const IVFSearchParameters* params,
330
+ IndexIVFStats* stats) const {
331
+ FAISS_THROW_IF_NOT_MSG(
332
+ !store_pairs, "store_pairs not supported in IVFDedup");
333
+
334
+ IndexIVFFlat::search_preassigned(
335
+ n, x, k, assign, centroid_dis, distances, labels, false, params);
336
+
337
+ std::vector<idx_t> labels2(k);
338
+ std::vector<float> dis2(k);
331
339
 
332
340
  for (int64_t i = 0; i < n; i++) {
333
- idx_t *labels1 = labels + i * k;
334
- float *dis1 = distances + i * k;
341
+ idx_t* labels1 = labels + i * k;
342
+ float* dis1 = distances + i * k;
335
343
  int64_t j = 0;
336
344
  for (; j < k; j++) {
337
- if (instances.find (labels1[j]) != instances.end ()) {
345
+ if (instances.find(labels1[j]) != instances.end()) {
338
346
  // a duplicate: special handling
339
347
  break;
340
348
  }
@@ -344,11 +352,11 @@ void IndexIVFFlatDedup::search_preassigned (
344
352
  int64_t j0 = j;
345
353
  int64_t rp = j;
346
354
  while (j < k) {
347
- auto range = instances.equal_range (labels1[rp]);
355
+ auto range = instances.equal_range(labels1[rp]);
348
356
  float dis = dis1[rp];
349
357
  labels2[j] = labels1[rp];
350
358
  dis2[j] = dis;
351
- j ++;
359
+ j++;
352
360
  for (auto it = range.first; j < k && it != range.second; ++it) {
353
361
  labels2[j] = it->second;
354
362
  dis2[j] = dis;
@@ -356,21 +364,18 @@ void IndexIVFFlatDedup::search_preassigned (
356
364
  }
357
365
  rp++;
358
366
  }
359
- memcpy (labels1 + j0, labels2.data() + j0,
360
- sizeof(labels1[0]) * (k - j0));
361
- memcpy (dis1 + j0, dis2.data() + j0,
362
- sizeof(dis2[0]) * (k - j0));
367
+ memcpy(labels1 + j0,
368
+ labels2.data() + j0,
369
+ sizeof(labels1[0]) * (k - j0));
370
+ memcpy(dis1 + j0, dis2.data() + j0, sizeof(dis2[0]) * (k - j0));
363
371
  }
364
372
  }
365
-
366
373
  }
367
374
 
368
-
369
- size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
370
- {
375
+ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel) {
371
376
  std::unordered_map<idx_t, idx_t> replace;
372
- std::vector<std::pair<idx_t, idx_t> > toadd;
373
- for (auto it = instances.begin(); it != instances.end(); ) {
377
+ std::vector<std::pair<idx_t, idx_t>> toadd;
378
+ for (auto it = instances.begin(); it != instances.end();) {
374
379
  if (sel.is_member(it->first)) {
375
380
  // then we erase this entry
376
381
  if (!sel.is_member(it->second)) {
@@ -378,8 +383,8 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
378
383
  if (replace.count(it->first) == 0) {
379
384
  replace[it->first] = it->second;
380
385
  } else { // remember we should add an element
381
- std::pair<idx_t, idx_t> new_entry (
382
- replace[it->first], it->second);
386
+ std::pair<idx_t, idx_t> new_entry(
387
+ replace[it->first], it->second);
383
388
  toadd.push_back(new_entry);
384
389
  }
385
390
  }
@@ -393,32 +398,34 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
393
398
  }
394
399
  }
395
400
 
396
- instances.insert (toadd.begin(), toadd.end());
401
+ instances.insert(toadd.begin(), toadd.end());
397
402
 
398
403
  // mostly copied from IndexIVF.cpp
399
404
 
400
- FAISS_THROW_IF_NOT_MSG (direct_map.no(),
401
- "direct map remove not implemented");
405
+ FAISS_THROW_IF_NOT_MSG(
406
+ direct_map.no(), "direct map remove not implemented");
402
407
 
403
408
  std::vector<int64_t> toremove(nlist);
404
409
 
405
410
  #pragma omp parallel for
406
411
  for (int64_t i = 0; i < nlist; i++) {
407
- int64_t l0 = invlists->list_size (i), l = l0, j = 0;
408
- InvertedLists::ScopedIds idsi (invlists, i);
412
+ int64_t l0 = invlists->list_size(i), l = l0, j = 0;
413
+ InvertedLists::ScopedIds idsi(invlists, i);
409
414
  while (j < l) {
410
- if (sel.is_member (idsi[j])) {
415
+ if (sel.is_member(idsi[j])) {
411
416
  if (replace.count(idsi[j]) == 0) {
412
417
  l--;
413
- invlists->update_entry (
414
- i, j,
415
- invlists->get_single_id (i, l),
416
- InvertedLists::ScopedCodes (invlists, i, l).get());
418
+ invlists->update_entry(
419
+ i,
420
+ j,
421
+ invlists->get_single_id(i, l),
422
+ InvertedLists::ScopedCodes(invlists, i, l).get());
417
423
  } else {
418
- invlists->update_entry (
419
- i, j,
420
- replace[idsi[j]],
421
- InvertedLists::ScopedCodes (invlists, i, j).get());
424
+ invlists->update_entry(
425
+ i,
426
+ j,
427
+ replace[idsi[j]],
428
+ InvertedLists::ScopedCodes(invlists, i, j).get());
422
429
  j++;
423
430
  }
424
431
  } else {
@@ -432,37 +439,28 @@ size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel)
432
439
  for (int64_t i = 0; i < nlist; i++) {
433
440
  if (toremove[i] > 0) {
434
441
  nremove += toremove[i];
435
- invlists->resize(
436
- i, invlists->list_size(i) - toremove[i]);
442
+ invlists->resize(i, invlists->list_size(i) - toremove[i]);
437
443
  }
438
444
  }
439
445
  ntotal -= nremove;
440
446
  return nremove;
441
447
  }
442
448
 
443
-
444
449
  void IndexIVFFlatDedup::range_search(
445
- idx_t ,
446
- const float* ,
447
- float ,
448
- RangeSearchResult* ) const
449
- {
450
- FAISS_THROW_MSG ("not implemented");
450
+ idx_t,
451
+ const float*,
452
+ float,
453
+ RangeSearchResult*) const {
454
+ FAISS_THROW_MSG("not implemented");
451
455
  }
452
456
 
453
- void IndexIVFFlatDedup::update_vectors (int , const idx_t *, const float *)
454
- {
455
- FAISS_THROW_MSG ("not implemented");
457
+ void IndexIVFFlatDedup::update_vectors(int, const idx_t*, const float*) {
458
+ FAISS_THROW_MSG("not implemented");
456
459
  }
457
460
 
458
-
459
- void IndexIVFFlatDedup::reconstruct_from_offset (
460
- int64_t , int64_t , float* ) const
461
- {
462
- FAISS_THROW_MSG ("not implemented");
461
+ void IndexIVFFlatDedup::reconstruct_from_offset(int64_t, int64_t, float*)
462
+ const {
463
+ FAISS_THROW_MSG("not implemented");
463
464
  }
464
465
 
465
-
466
-
467
-
468
466
  } // namespace faiss