faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -12,7 +12,6 @@
12
12
  #include <faiss/IndexIVF.h>
13
13
  #include <faiss/impl/AuxIndexStructures.h>
14
14
 
15
-
16
15
  namespace faiss {
17
16
 
18
17
  /**
@@ -22,15 +21,14 @@ namespace faiss {
22
21
  */
23
22
 
24
23
  struct ScalarQuantizer {
25
-
26
24
  enum QuantizerType {
27
- QT_8bit, ///< 8 bits per component
28
- QT_4bit, ///< 4 bits per component
29
- QT_8bit_uniform, ///< same, shared range for all dimensions
25
+ QT_8bit, ///< 8 bits per component
26
+ QT_4bit, ///< 4 bits per component
27
+ QT_8bit_uniform, ///< same, shared range for all dimensions
30
28
  QT_4bit_uniform,
31
29
  QT_fp16,
32
- QT_8bit_direct, ///< fast indexing of uint8s
33
- QT_6bit, ///< 6 bits per component
30
+ QT_8bit_direct, ///< fast indexing of uint8s
31
+ QT_6bit, ///< 6 bits per component
34
32
  };
35
33
 
36
34
  QuantizerType qtype;
@@ -41,10 +39,10 @@ struct ScalarQuantizer {
41
39
 
42
40
  // rangestat_arg.
43
41
  enum RangeStat {
44
- RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)]
45
- RS_meanstd, ///< [mean - std * rs, mean + std * rs]
46
- RS_quantiles, ///< [Q(rs), Q(1-rs)]
47
- RS_optim, ///< alternate optimization of reconstruction error
42
+ RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)]
43
+ RS_meanstd, ///< [mean - std * rs, mean + std * rs]
44
+ RS_quantiles, ///< [Q(rs), Q(1-rs)]
45
+ RS_optim, ///< alternate optimization of reconstruction error
48
46
  };
49
47
 
50
48
  RangeStat rangestat;
@@ -62,29 +60,35 @@ struct ScalarQuantizer {
62
60
  /// trained values (including the range)
63
61
  std::vector<float> trained;
64
62
 
65
- ScalarQuantizer (size_t d, QuantizerType qtype);
66
- ScalarQuantizer ();
63
+ ScalarQuantizer(size_t d, QuantizerType qtype);
64
+ ScalarQuantizer();
67
65
 
68
66
  /// updates internal values based on qtype and d
69
- void set_derived_sizes ();
67
+ void set_derived_sizes();
70
68
 
71
- void train (size_t n, const float *x);
69
+ void train(size_t n, const float* x);
72
70
 
73
71
  /// Used by an IVF index to train based on the residuals
74
- void train_residual (size_t n,
75
- const float *x,
76
- Index *quantizer,
77
- bool by_residual,
78
- bool verbose);
79
-
80
- /// same as compute_code for several vectors
81
- void compute_codes (const float * x,
82
- uint8_t * codes,
83
- size_t n) const ;
84
-
85
- /// decode a vector from a given code (or n vectors if third argument)
86
- void decode (const uint8_t *code, float *x, size_t n) const;
87
-
72
+ void train_residual(
73
+ size_t n,
74
+ const float* x,
75
+ Index* quantizer,
76
+ bool by_residual,
77
+ bool verbose);
78
+
79
+ /** Encode a set of vectors
80
+ *
81
+ * @param x vectors to encode, size n * d
82
+ * @param codes output codes, size n * code_size
83
+ */
84
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const;
85
+
86
+ /** Decode a set of vectors
87
+ *
88
+ * @param codes codes to decode, size n * code_size
89
+ * @param x output vectors, size n * d
90
+ */
91
+ void decode(const uint8_t* code, float* x, size_t n) const;
88
92
 
89
93
  /*****************************************************
90
94
  * Objects that provide methods for encoding/decoding, distance
@@ -93,34 +97,32 @@ struct ScalarQuantizer {
93
97
 
94
98
  struct Quantizer {
95
99
  // encodes one vector. Assumes code is filled with 0s on input!
96
- virtual void encode_vector(const float *x, uint8_t *code) const = 0;
97
- virtual void decode_vector(const uint8_t *code, float *x) const = 0;
100
+ virtual void encode_vector(const float* x, uint8_t* code) const = 0;
101
+ virtual void decode_vector(const uint8_t* code, float* x) const = 0;
98
102
 
99
103
  virtual ~Quantizer() {}
100
104
  };
101
105
 
102
- Quantizer * select_quantizer() const;
106
+ Quantizer* select_quantizer() const;
103
107
 
104
- struct SQDistanceComputer: DistanceComputer {
105
-
106
- const float *q;
107
- const uint8_t *codes;
108
+ struct SQDistanceComputer : DistanceComputer {
109
+ const float* q;
110
+ const uint8_t* codes;
108
111
  size_t code_size;
109
112
 
110
- SQDistanceComputer (): q(nullptr), codes (nullptr), code_size (0)
111
- {}
113
+ SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
112
114
 
115
+ virtual float query_to_code(const uint8_t* code) const = 0;
113
116
  };
114
117
 
115
- SQDistanceComputer *get_distance_computer (MetricType metric = METRIC_L2)
116
- const;
117
-
118
- InvertedListScanner *select_InvertedListScanner
119
- (MetricType mt, const Index *quantizer, bool store_pairs,
120
- bool by_residual=false) const;
118
+ SQDistanceComputer* get_distance_computer(
119
+ MetricType metric = METRIC_L2) const;
121
120
 
121
+ InvertedListScanner* select_InvertedListScanner(
122
+ MetricType mt,
123
+ const Index* quantizer,
124
+ bool store_pairs,
125
+ bool by_residual = false) const;
122
126
  };
123
127
 
124
-
125
-
126
128
  } // namespace faiss
@@ -13,180 +13,178 @@ namespace faiss {
13
13
 
14
14
  template <typename IndexT>
15
15
  ThreadedIndex<IndexT>::ThreadedIndex(bool threaded)
16
- // 0 is default dimension
17
- : ThreadedIndex(0, threaded) {
18
- }
16
+ // 0 is default dimension
17
+ : ThreadedIndex(0, threaded) {}
19
18
 
20
19
  template <typename IndexT>
21
20
  ThreadedIndex<IndexT>::ThreadedIndex(int d, bool threaded)
22
- : IndexT(d),
23
- own_fields(false),
24
- isThreaded_(threaded) {
25
- }
21
+ : IndexT(d), own_fields(false), isThreaded_(threaded) {}
26
22
 
27
23
  template <typename IndexT>
28
24
  ThreadedIndex<IndexT>::~ThreadedIndex() {
29
- for (auto& p : indices_) {
30
- if (isThreaded_) {
31
- // should have worker thread
32
- FAISS_ASSERT((bool) p.second);
33
-
34
- // This will also flush all pending work
35
- p.second->stop();
36
- p.second->waitForThreadExit();
37
- } else {
38
- // should not have worker thread
39
- FAISS_ASSERT(!(bool) p.second);
40
- }
41
-
42
- if (own_fields) {
43
- delete p.first;
25
+ for (auto& p : indices_) {
26
+ if (isThreaded_) {
27
+ // should have worker thread
28
+ FAISS_ASSERT((bool)p.second);
29
+
30
+ // This will also flush all pending work
31
+ p.second->stop();
32
+ p.second->waitForThreadExit();
33
+ } else {
34
+ // should not have worker thread
35
+ FAISS_ASSERT(!(bool)p.second);
36
+ }
37
+
38
+ if (own_fields) {
39
+ delete p.first;
40
+ }
44
41
  }
45
- }
46
42
  }
47
43
 
48
44
  template <typename IndexT>
49
45
  void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
50
- // We inherit the dimension from the first index added to us if we don't have
51
- // a set dimension
52
- if (indices_.empty() && this->d == 0) {
53
- this->d = index->d;
54
- }
55
-
56
- // The new index must match our set dimension
57
- FAISS_THROW_IF_NOT_FMT(this->d == index->d,
58
- "addIndex: dimension mismatch for "
59
- "newly added index; expecting dim %d, "
60
- "new index has dim %d",
61
- this->d, index->d);
62
-
63
- if (!indices_.empty()) {
64
- auto& existing = indices_.front().first;
65
-
66
- FAISS_THROW_IF_NOT_MSG(index->metric_type == existing->metric_type,
67
- "addIndex: newly added index is "
68
- "of different metric type than old index");
69
-
70
- // Make sure this index is not duplicated
71
- for (auto& p : indices_) {
72
- FAISS_THROW_IF_NOT_MSG(p.first != index,
73
- "addIndex: attempting to add index "
74
- "that is already in the collection");
46
+ // We inherit the dimension from the first index added to us if we don't
47
+ // have a set dimension
48
+ if (indices_.empty() && this->d == 0) {
49
+ this->d = index->d;
50
+ }
51
+
52
+ // The new index must match our set dimension
53
+ FAISS_THROW_IF_NOT_FMT(
54
+ this->d == index->d,
55
+ "addIndex: dimension mismatch for "
56
+ "newly added index; expecting dim %d, "
57
+ "new index has dim %d",
58
+ this->d,
59
+ index->d);
60
+
61
+ if (!indices_.empty()) {
62
+ auto& existing = indices_.front().first;
63
+
64
+ FAISS_THROW_IF_NOT_MSG(
65
+ index->metric_type == existing->metric_type,
66
+ "addIndex: newly added index is "
67
+ "of different metric type than old index");
68
+
69
+ // Make sure this index is not duplicated
70
+ for (auto& p : indices_) {
71
+ FAISS_THROW_IF_NOT_MSG(
72
+ p.first != index,
73
+ "addIndex: attempting to add index "
74
+ "that is already in the collection");
75
+ }
75
76
  }
76
- }
77
77
 
78
- indices_.emplace_back(
79
- std::make_pair(
80
- index,
81
- std::unique_ptr<WorkerThread>(isThreaded_ ?
82
- new WorkerThread : nullptr)));
78
+ indices_.emplace_back(std::make_pair(
79
+ index,
80
+ std::unique_ptr<WorkerThread>(
81
+ isThreaded_ ? new WorkerThread : nullptr)));
83
82
 
84
- onAfterAddIndex(index);
83
+ onAfterAddIndex(index);
85
84
  }
86
85
 
87
86
  template <typename IndexT>
88
87
  void ThreadedIndex<IndexT>::removeIndex(IndexT* index) {
89
- for (auto it = indices_.begin(); it != indices_.end(); ++it) {
90
- if (it->first == index) {
91
- // This is our index; stop the worker thread before removing it,
92
- // to ensure that it has finished before function exit
93
- if (isThreaded_) {
94
- // should have worker thread
95
- FAISS_ASSERT((bool) it->second);
96
- it->second->stop();
97
- it->second->waitForThreadExit();
98
- } else {
99
- // should not have worker thread
100
- FAISS_ASSERT(!(bool) it->second);
101
- }
102
-
103
- indices_.erase(it);
104
- onAfterRemoveIndex(index);
105
-
106
- if (own_fields) {
107
- delete index;
108
- }
109
-
110
- return;
88
+ for (auto it = indices_.begin(); it != indices_.end(); ++it) {
89
+ if (it->first == index) {
90
+ // This is our index; stop the worker thread before removing it,
91
+ // to ensure that it has finished before function exit
92
+ if (isThreaded_) {
93
+ // should have worker thread
94
+ FAISS_ASSERT((bool)it->second);
95
+ it->second->stop();
96
+ it->second->waitForThreadExit();
97
+ } else {
98
+ // should not have worker thread
99
+ FAISS_ASSERT(!(bool)it->second);
100
+ }
101
+
102
+ indices_.erase(it);
103
+ onAfterRemoveIndex(index);
104
+
105
+ if (own_fields) {
106
+ delete index;
107
+ }
108
+
109
+ return;
110
+ }
111
111
  }
112
- }
113
112
 
114
- // could not find our index
115
- FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
113
+ // could not find our index
114
+ FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
116
115
  }
117
116
 
118
117
  template <typename IndexT>
119
118
  void ThreadedIndex<IndexT>::runOnIndex(std::function<void(int, IndexT*)> f) {
120
- if (isThreaded_) {
121
- std::vector<std::future<bool>> v;
122
-
123
- for (int i = 0; i < this->indices_.size(); ++i) {
124
- auto& p = this->indices_[i];
125
- auto indexPtr = p.first;
126
- v.emplace_back(p.second->add([f, i, indexPtr](){ f(i, indexPtr); }));
127
- }
119
+ if (isThreaded_) {
120
+ std::vector<std::future<bool>> v;
128
121
 
129
- waitAndHandleFutures(v);
130
- } else {
131
- // Multiple exceptions may be thrown; gather them as we encounter them,
132
- // while letting everything else run to completion
133
- std::vector<std::pair<int, std::exception_ptr>> exceptions;
122
+ for (int i = 0; i < this->indices_.size(); ++i) {
123
+ auto& p = this->indices_[i];
124
+ auto indexPtr = p.first;
125
+ v.emplace_back(
126
+ p.second->add([f, i, indexPtr]() { f(i, indexPtr); }));
127
+ }
134
128
 
135
- for (int i = 0; i < this->indices_.size(); ++i) {
136
- auto& p = this->indices_[i];
137
- try {
138
- f(i, p.first);
139
- } catch (...) {
140
- exceptions.emplace_back(std::make_pair(i, std::current_exception()));
141
- }
129
+ waitAndHandleFutures(v);
130
+ } else {
131
+ // Multiple exceptions may be thrown; gather them as we encounter them,
132
+ // while letting everything else run to completion
133
+ std::vector<std::pair<int, std::exception_ptr>> exceptions;
134
+
135
+ for (int i = 0; i < this->indices_.size(); ++i) {
136
+ auto& p = this->indices_[i];
137
+ try {
138
+ f(i, p.first);
139
+ } catch (...) {
140
+ exceptions.emplace_back(
141
+ std::make_pair(i, std::current_exception()));
142
+ }
143
+ }
144
+
145
+ handleExceptions(exceptions);
142
146
  }
143
-
144
- handleExceptions(exceptions);
145
- }
146
147
  }
147
148
 
148
149
  template <typename IndexT>
149
150
  void ThreadedIndex<IndexT>::runOnIndex(
150
- std::function<void(int, const IndexT*)> f) const {
151
- const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
152
- [f](int i, IndexT* idx){ f(i, idx); });
151
+ std::function<void(int, const IndexT*)> f) const {
152
+ const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
153
+ [f](int i, IndexT* idx) { f(i, idx); });
153
154
  }
154
155
 
155
156
  template <typename IndexT>
156
157
  void ThreadedIndex<IndexT>::reset() {
157
- runOnIndex([](int, IndexT* index){ index->reset(); });
158
- this->ntotal = 0;
159
- this->is_trained = false;
158
+ runOnIndex([](int, IndexT* index) { index->reset(); });
159
+ this->ntotal = 0;
160
+ this->is_trained = false;
160
161
  }
161
162
 
162
163
  template <typename IndexT>
163
- void
164
- ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {
165
- }
164
+ void ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {}
166
165
 
167
166
  template <typename IndexT>
168
- void
169
- ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {
170
- }
167
+ void ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {}
171
168
 
172
169
  template <typename IndexT>
173
- void
174
- ThreadedIndex<IndexT>::waitAndHandleFutures(std::vector<std::future<bool>>& v) {
175
- // Blocking wait for completion for all of the indices, capturing any
176
- // exceptions that are generated
177
- std::vector<std::pair<int, std::exception_ptr>> exceptions;
178
-
179
- for (int i = 0; i < v.size(); ++i) {
180
- auto& fut = v[i];
181
-
182
- try {
183
- fut.get();
184
- } catch (...) {
185
- exceptions.emplace_back(std::make_pair(i, std::current_exception()));
170
+ void ThreadedIndex<IndexT>::waitAndHandleFutures(
171
+ std::vector<std::future<bool>>& v) {
172
+ // Blocking wait for completion for all of the indices, capturing any
173
+ // exceptions that are generated
174
+ std::vector<std::pair<int, std::exception_ptr>> exceptions;
175
+
176
+ for (int i = 0; i < v.size(); ++i) {
177
+ auto& fut = v[i];
178
+
179
+ try {
180
+ fut.get();
181
+ } catch (...) {
182
+ exceptions.emplace_back(
183
+ std::make_pair(i, std::current_exception()));
184
+ }
186
185
  }
187
- }
188
186
 
189
- handleExceptions(exceptions);
187
+ handleExceptions(exceptions);
190
188
  }
191
189
 
192
- } // namespace
190
+ } // namespace faiss