faiss 0.1.7 → 0.2.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -12,7 +12,6 @@
12
12
  #include <faiss/IndexIVF.h>
13
13
  #include <faiss/impl/AuxIndexStructures.h>
14
14
 
15
-
16
15
  namespace faiss {
17
16
 
18
17
  /**
@@ -22,15 +21,14 @@ namespace faiss {
22
21
  */
23
22
 
24
23
  struct ScalarQuantizer {
25
-
26
24
  enum QuantizerType {
27
- QT_8bit, ///< 8 bits per component
28
- QT_4bit, ///< 4 bits per component
29
- QT_8bit_uniform, ///< same, shared range for all dimensions
25
+ QT_8bit, ///< 8 bits per component
26
+ QT_4bit, ///< 4 bits per component
27
+ QT_8bit_uniform, ///< same, shared range for all dimensions
30
28
  QT_4bit_uniform,
31
29
  QT_fp16,
32
- QT_8bit_direct, ///< fast indexing of uint8s
33
- QT_6bit, ///< 6 bits per component
30
+ QT_8bit_direct, ///< fast indexing of uint8s
31
+ QT_6bit, ///< 6 bits per component
34
32
  };
35
33
 
36
34
  QuantizerType qtype;
@@ -41,10 +39,10 @@ struct ScalarQuantizer {
41
39
 
42
40
  // rangestat_arg.
43
41
  enum RangeStat {
44
- RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)]
45
- RS_meanstd, ///< [mean - std * rs, mean + std * rs]
46
- RS_quantiles, ///< [Q(rs), Q(1-rs)]
47
- RS_optim, ///< alternate optimization of reconstruction error
42
+ RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)]
43
+ RS_meanstd, ///< [mean - std * rs, mean + std * rs]
44
+ RS_quantiles, ///< [Q(rs), Q(1-rs)]
45
+ RS_optim, ///< alternate optimization of reconstruction error
48
46
  };
49
47
 
50
48
  RangeStat rangestat;
@@ -62,29 +60,35 @@ struct ScalarQuantizer {
62
60
  /// trained values (including the range)
63
61
  std::vector<float> trained;
64
62
 
65
- ScalarQuantizer (size_t d, QuantizerType qtype);
66
- ScalarQuantizer ();
63
+ ScalarQuantizer(size_t d, QuantizerType qtype);
64
+ ScalarQuantizer();
67
65
 
68
66
  /// updates internal values based on qtype and d
69
- void set_derived_sizes ();
67
+ void set_derived_sizes();
70
68
 
71
- void train (size_t n, const float *x);
69
+ void train(size_t n, const float* x);
72
70
 
73
71
  /// Used by an IVF index to train based on the residuals
74
- void train_residual (size_t n,
75
- const float *x,
76
- Index *quantizer,
77
- bool by_residual,
78
- bool verbose);
79
-
80
- /// same as compute_code for several vectors
81
- void compute_codes (const float * x,
82
- uint8_t * codes,
83
- size_t n) const ;
84
-
85
- /// decode a vector from a given code (or n vectors if third argument)
86
- void decode (const uint8_t *code, float *x, size_t n) const;
87
-
72
+ void train_residual(
73
+ size_t n,
74
+ const float* x,
75
+ Index* quantizer,
76
+ bool by_residual,
77
+ bool verbose);
78
+
79
+ /** Encode a set of vectors
80
+ *
81
+ * @param x vectors to encode, size n * d
82
+ * @param codes output codes, size n * code_size
83
+ */
84
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const;
85
+
86
+ /** Decode a set of vectors
87
+ *
88
+ * @param codes codes to decode, size n * code_size
89
+ * @param x output vectors, size n * d
90
+ */
91
+ void decode(const uint8_t* code, float* x, size_t n) const;
88
92
 
89
93
  /*****************************************************
90
94
  * Objects that provide methods for encoding/decoding, distance
@@ -93,34 +97,32 @@ struct ScalarQuantizer {
93
97
 
94
98
  struct Quantizer {
95
99
  // encodes one vector. Assumes code is filled with 0s on input!
96
- virtual void encode_vector(const float *x, uint8_t *code) const = 0;
97
- virtual void decode_vector(const uint8_t *code, float *x) const = 0;
100
+ virtual void encode_vector(const float* x, uint8_t* code) const = 0;
101
+ virtual void decode_vector(const uint8_t* code, float* x) const = 0;
98
102
 
99
103
  virtual ~Quantizer() {}
100
104
  };
101
105
 
102
- Quantizer * select_quantizer() const;
106
+ Quantizer* select_quantizer() const;
103
107
 
104
- struct SQDistanceComputer: DistanceComputer {
105
-
106
- const float *q;
107
- const uint8_t *codes;
108
+ struct SQDistanceComputer : DistanceComputer {
109
+ const float* q;
110
+ const uint8_t* codes;
108
111
  size_t code_size;
109
112
 
110
- SQDistanceComputer (): q(nullptr), codes (nullptr), code_size (0)
111
- {}
113
+ SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
112
114
 
115
+ virtual float query_to_code(const uint8_t* code) const = 0;
113
116
  };
114
117
 
115
- SQDistanceComputer *get_distance_computer (MetricType metric = METRIC_L2)
116
- const;
117
-
118
- InvertedListScanner *select_InvertedListScanner
119
- (MetricType mt, const Index *quantizer, bool store_pairs,
120
- bool by_residual=false) const;
118
+ SQDistanceComputer* get_distance_computer(
119
+ MetricType metric = METRIC_L2) const;
121
120
 
121
+ InvertedListScanner* select_InvertedListScanner(
122
+ MetricType mt,
123
+ const Index* quantizer,
124
+ bool store_pairs,
125
+ bool by_residual = false) const;
122
126
  };
123
127
 
124
-
125
-
126
128
  } // namespace faiss
@@ -13,180 +13,178 @@ namespace faiss {
13
13
 
14
14
  template <typename IndexT>
15
15
  ThreadedIndex<IndexT>::ThreadedIndex(bool threaded)
16
- // 0 is default dimension
17
- : ThreadedIndex(0, threaded) {
18
- }
16
+ // 0 is default dimension
17
+ : ThreadedIndex(0, threaded) {}
19
18
 
20
19
  template <typename IndexT>
21
20
  ThreadedIndex<IndexT>::ThreadedIndex(int d, bool threaded)
22
- : IndexT(d),
23
- own_fields(false),
24
- isThreaded_(threaded) {
25
- }
21
+ : IndexT(d), own_fields(false), isThreaded_(threaded) {}
26
22
 
27
23
  template <typename IndexT>
28
24
  ThreadedIndex<IndexT>::~ThreadedIndex() {
29
- for (auto& p : indices_) {
30
- if (isThreaded_) {
31
- // should have worker thread
32
- FAISS_ASSERT((bool) p.second);
33
-
34
- // This will also flush all pending work
35
- p.second->stop();
36
- p.second->waitForThreadExit();
37
- } else {
38
- // should not have worker thread
39
- FAISS_ASSERT(!(bool) p.second);
40
- }
41
-
42
- if (own_fields) {
43
- delete p.first;
25
+ for (auto& p : indices_) {
26
+ if (isThreaded_) {
27
+ // should have worker thread
28
+ FAISS_ASSERT((bool)p.second);
29
+
30
+ // This will also flush all pending work
31
+ p.second->stop();
32
+ p.second->waitForThreadExit();
33
+ } else {
34
+ // should not have worker thread
35
+ FAISS_ASSERT(!(bool)p.second);
36
+ }
37
+
38
+ if (own_fields) {
39
+ delete p.first;
40
+ }
44
41
  }
45
- }
46
42
  }
47
43
 
48
44
  template <typename IndexT>
49
45
  void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
50
- // We inherit the dimension from the first index added to us if we don't have
51
- // a set dimension
52
- if (indices_.empty() && this->d == 0) {
53
- this->d = index->d;
54
- }
55
-
56
- // The new index must match our set dimension
57
- FAISS_THROW_IF_NOT_FMT(this->d == index->d,
58
- "addIndex: dimension mismatch for "
59
- "newly added index; expecting dim %d, "
60
- "new index has dim %d",
61
- this->d, index->d);
62
-
63
- if (!indices_.empty()) {
64
- auto& existing = indices_.front().first;
65
-
66
- FAISS_THROW_IF_NOT_MSG(index->metric_type == existing->metric_type,
67
- "addIndex: newly added index is "
68
- "of different metric type than old index");
69
-
70
- // Make sure this index is not duplicated
71
- for (auto& p : indices_) {
72
- FAISS_THROW_IF_NOT_MSG(p.first != index,
73
- "addIndex: attempting to add index "
74
- "that is already in the collection");
46
+ // We inherit the dimension from the first index added to us if we don't
47
+ // have a set dimension
48
+ if (indices_.empty() && this->d == 0) {
49
+ this->d = index->d;
50
+ }
51
+
52
+ // The new index must match our set dimension
53
+ FAISS_THROW_IF_NOT_FMT(
54
+ this->d == index->d,
55
+ "addIndex: dimension mismatch for "
56
+ "newly added index; expecting dim %d, "
57
+ "new index has dim %d",
58
+ this->d,
59
+ index->d);
60
+
61
+ if (!indices_.empty()) {
62
+ auto& existing = indices_.front().first;
63
+
64
+ FAISS_THROW_IF_NOT_MSG(
65
+ index->metric_type == existing->metric_type,
66
+ "addIndex: newly added index is "
67
+ "of different metric type than old index");
68
+
69
+ // Make sure this index is not duplicated
70
+ for (auto& p : indices_) {
71
+ FAISS_THROW_IF_NOT_MSG(
72
+ p.first != index,
73
+ "addIndex: attempting to add index "
74
+ "that is already in the collection");
75
+ }
75
76
  }
76
- }
77
77
 
78
- indices_.emplace_back(
79
- std::make_pair(
80
- index,
81
- std::unique_ptr<WorkerThread>(isThreaded_ ?
82
- new WorkerThread : nullptr)));
78
+ indices_.emplace_back(std::make_pair(
79
+ index,
80
+ std::unique_ptr<WorkerThread>(
81
+ isThreaded_ ? new WorkerThread : nullptr)));
83
82
 
84
- onAfterAddIndex(index);
83
+ onAfterAddIndex(index);
85
84
  }
86
85
 
87
86
  template <typename IndexT>
88
87
  void ThreadedIndex<IndexT>::removeIndex(IndexT* index) {
89
- for (auto it = indices_.begin(); it != indices_.end(); ++it) {
90
- if (it->first == index) {
91
- // This is our index; stop the worker thread before removing it,
92
- // to ensure that it has finished before function exit
93
- if (isThreaded_) {
94
- // should have worker thread
95
- FAISS_ASSERT((bool) it->second);
96
- it->second->stop();
97
- it->second->waitForThreadExit();
98
- } else {
99
- // should not have worker thread
100
- FAISS_ASSERT(!(bool) it->second);
101
- }
102
-
103
- indices_.erase(it);
104
- onAfterRemoveIndex(index);
105
-
106
- if (own_fields) {
107
- delete index;
108
- }
109
-
110
- return;
88
+ for (auto it = indices_.begin(); it != indices_.end(); ++it) {
89
+ if (it->first == index) {
90
+ // This is our index; stop the worker thread before removing it,
91
+ // to ensure that it has finished before function exit
92
+ if (isThreaded_) {
93
+ // should have worker thread
94
+ FAISS_ASSERT((bool)it->second);
95
+ it->second->stop();
96
+ it->second->waitForThreadExit();
97
+ } else {
98
+ // should not have worker thread
99
+ FAISS_ASSERT(!(bool)it->second);
100
+ }
101
+
102
+ indices_.erase(it);
103
+ onAfterRemoveIndex(index);
104
+
105
+ if (own_fields) {
106
+ delete index;
107
+ }
108
+
109
+ return;
110
+ }
111
111
  }
112
- }
113
112
 
114
- // could not find our index
115
- FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
113
+ // could not find our index
114
+ FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
116
115
  }
117
116
 
118
117
  template <typename IndexT>
119
118
  void ThreadedIndex<IndexT>::runOnIndex(std::function<void(int, IndexT*)> f) {
120
- if (isThreaded_) {
121
- std::vector<std::future<bool>> v;
122
-
123
- for (int i = 0; i < this->indices_.size(); ++i) {
124
- auto& p = this->indices_[i];
125
- auto indexPtr = p.first;
126
- v.emplace_back(p.second->add([f, i, indexPtr](){ f(i, indexPtr); }));
127
- }
119
+ if (isThreaded_) {
120
+ std::vector<std::future<bool>> v;
128
121
 
129
- waitAndHandleFutures(v);
130
- } else {
131
- // Multiple exceptions may be thrown; gather them as we encounter them,
132
- // while letting everything else run to completion
133
- std::vector<std::pair<int, std::exception_ptr>> exceptions;
122
+ for (int i = 0; i < this->indices_.size(); ++i) {
123
+ auto& p = this->indices_[i];
124
+ auto indexPtr = p.first;
125
+ v.emplace_back(
126
+ p.second->add([f, i, indexPtr]() { f(i, indexPtr); }));
127
+ }
134
128
 
135
- for (int i = 0; i < this->indices_.size(); ++i) {
136
- auto& p = this->indices_[i];
137
- try {
138
- f(i, p.first);
139
- } catch (...) {
140
- exceptions.emplace_back(std::make_pair(i, std::current_exception()));
141
- }
129
+ waitAndHandleFutures(v);
130
+ } else {
131
+ // Multiple exceptions may be thrown; gather them as we encounter them,
132
+ // while letting everything else run to completion
133
+ std::vector<std::pair<int, std::exception_ptr>> exceptions;
134
+
135
+ for (int i = 0; i < this->indices_.size(); ++i) {
136
+ auto& p = this->indices_[i];
137
+ try {
138
+ f(i, p.first);
139
+ } catch (...) {
140
+ exceptions.emplace_back(
141
+ std::make_pair(i, std::current_exception()));
142
+ }
143
+ }
144
+
145
+ handleExceptions(exceptions);
142
146
  }
143
-
144
- handleExceptions(exceptions);
145
- }
146
147
  }
147
148
 
148
149
  template <typename IndexT>
149
150
  void ThreadedIndex<IndexT>::runOnIndex(
150
- std::function<void(int, const IndexT*)> f) const {
151
- const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
152
- [f](int i, IndexT* idx){ f(i, idx); });
151
+ std::function<void(int, const IndexT*)> f) const {
152
+ const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
153
+ [f](int i, IndexT* idx) { f(i, idx); });
153
154
  }
154
155
 
155
156
  template <typename IndexT>
156
157
  void ThreadedIndex<IndexT>::reset() {
157
- runOnIndex([](int, IndexT* index){ index->reset(); });
158
- this->ntotal = 0;
159
- this->is_trained = false;
158
+ runOnIndex([](int, IndexT* index) { index->reset(); });
159
+ this->ntotal = 0;
160
+ this->is_trained = false;
160
161
  }
161
162
 
162
163
  template <typename IndexT>
163
- void
164
- ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {
165
- }
164
+ void ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {}
166
165
 
167
166
  template <typename IndexT>
168
- void
169
- ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {
170
- }
167
+ void ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {}
171
168
 
172
169
  template <typename IndexT>
173
- void
174
- ThreadedIndex<IndexT>::waitAndHandleFutures(std::vector<std::future<bool>>& v) {
175
- // Blocking wait for completion for all of the indices, capturing any
176
- // exceptions that are generated
177
- std::vector<std::pair<int, std::exception_ptr>> exceptions;
178
-
179
- for (int i = 0; i < v.size(); ++i) {
180
- auto& fut = v[i];
181
-
182
- try {
183
- fut.get();
184
- } catch (...) {
185
- exceptions.emplace_back(std::make_pair(i, std::current_exception()));
170
+ void ThreadedIndex<IndexT>::waitAndHandleFutures(
171
+ std::vector<std::future<bool>>& v) {
172
+ // Blocking wait for completion for all of the indices, capturing any
173
+ // exceptions that are generated
174
+ std::vector<std::pair<int, std::exception_ptr>> exceptions;
175
+
176
+ for (int i = 0; i < v.size(); ++i) {
177
+ auto& fut = v[i];
178
+
179
+ try {
180
+ fut.get();
181
+ } catch (...) {
182
+ exceptions.emplace_back(
183
+ std::make_pair(i, std::current_exception()));
184
+ }
186
185
  }
187
- }
188
186
 
189
- handleExceptions(exceptions);
187
+ handleExceptions(exceptions);
190
188
  }
191
189
 
192
- } // namespace
190
+ } // namespace faiss
@@ -19,62 +19,68 @@ namespace faiss {
19
19
  /// The interface to this class itself is not thread safe
20
20
  template <typename IndexT>
21
21
  class ThreadedIndex : public IndexT {
22
- public:
23
- explicit ThreadedIndex(bool threaded);
24
- explicit ThreadedIndex(int d, bool threaded);
25
-
26
- ~ThreadedIndex() override;
27
-
28
- /// override an index that is managed by ourselves.
29
- /// WARNING: once an index is added, it becomes unsafe to touch it from any
30
- /// other thread than that on which is managing it, until we are shut
31
- /// down. Use runOnIndex to perform work on it instead.
32
- void addIndex(IndexT* index);
33
-
34
- /// Remove an index that is managed by ourselves.
35
- /// This will flush all pending work on that index, and then shut
36
- /// down its managing thread, and will remove the index.
37
- void removeIndex(IndexT* index);
38
-
39
- /// Run a function on all indices, in the thread that the index is
40
- /// managed in.
41
- /// Function arguments are (index in collection, index pointer)
42
- void runOnIndex(std::function<void(int, IndexT*)> f);
43
- void runOnIndex(std::function<void(int, const IndexT*)> f) const;
44
-
45
- /// faiss::Index API
46
- /// All indices receive the same call
47
- void reset() override;
48
-
49
- /// Returns the number of sub-indices
50
- int count() const { return indices_.size(); }
51
-
52
- /// Returns the i-th sub-index
53
- IndexT* at(int i) { return indices_[i].first; }
54
-
55
- /// Returns the i-th sub-index (const version)
56
- const IndexT* at(int i) const { return indices_[i].first; }
57
-
58
- /// Whether or not we are responsible for deleting our contained indices
59
- bool own_fields;
60
-
61
- protected:
62
- /// Called just after an index is added
63
- virtual void onAfterAddIndex(IndexT* index);
64
-
65
- /// Called just after an index is removed
66
- virtual void onAfterRemoveIndex(IndexT* index);
67
-
68
- protected:
69
- static void waitAndHandleFutures(std::vector<std::future<bool>>& v);
70
-
71
- /// Collection of Index instances, with their managing worker thread if any
72
- std::vector<std::pair<IndexT*, std::unique_ptr<WorkerThread>>> indices_;
73
-
74
- /// Is this index multi-threaded?
75
- bool isThreaded_;
22
+ public:
23
+ explicit ThreadedIndex(bool threaded);
24
+ explicit ThreadedIndex(int d, bool threaded);
25
+
26
+ ~ThreadedIndex() override;
27
+
28
+ /// override an index that is managed by ourselves.
29
+ /// WARNING: once an index is added, it becomes unsafe to touch it from any
30
+ /// other thread than that on which is managing it, until we are shut
31
+ /// down. Use runOnIndex to perform work on it instead.
32
+ void addIndex(IndexT* index);
33
+
34
+ /// Remove an index that is managed by ourselves.
35
+ /// This will flush all pending work on that index, and then shut
36
+ /// down its managing thread, and will remove the index.
37
+ void removeIndex(IndexT* index);
38
+
39
+ /// Run a function on all indices, in the thread that the index is
40
+ /// managed in.
41
+ /// Function arguments are (index in collection, index pointer)
42
+ void runOnIndex(std::function<void(int, IndexT*)> f);
43
+ void runOnIndex(std::function<void(int, const IndexT*)> f) const;
44
+
45
+ /// faiss::Index API
46
+ /// All indices receive the same call
47
+ void reset() override;
48
+
49
+ /// Returns the number of sub-indices
50
+ int count() const {
51
+ return indices_.size();
52
+ }
53
+
54
+ /// Returns the i-th sub-index
55
+ IndexT* at(int i) {
56
+ return indices_[i].first;
57
+ }
58
+
59
+ /// Returns the i-th sub-index (const version)
60
+ const IndexT* at(int i) const {
61
+ return indices_[i].first;
62
+ }
63
+
64
+ /// Whether or not we are responsible for deleting our contained indices
65
+ bool own_fields;
66
+
67
+ protected:
68
+ /// Called just after an index is added
69
+ virtual void onAfterAddIndex(IndexT* index);
70
+
71
+ /// Called just after an index is removed
72
+ virtual void onAfterRemoveIndex(IndexT* index);
73
+
74
+ protected:
75
+ static void waitAndHandleFutures(std::vector<std::future<bool>>& v);
76
+
77
+ /// Collection of Index instances, with their managing worker thread if any
78
+ std::vector<std::pair<IndexT*, std::unique_ptr<WorkerThread>>> indices_;
79
+
80
+ /// Is this index multi-threaded?
81
+ bool isThreaded_;
76
82
  };
77
83
 
78
- } // namespace
84
+ } // namespace faiss
79
85
 
80
86
  #include <faiss/impl/ThreadedIndex-inl.h>