faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -12,7 +12,6 @@
12
12
  #include <faiss/IndexIVF.h>
13
13
  #include <faiss/impl/AuxIndexStructures.h>
14
14
 
15
-
16
15
  namespace faiss {
17
16
 
18
17
  /**
@@ -22,15 +21,14 @@ namespace faiss {
22
21
  */
23
22
 
24
23
  struct ScalarQuantizer {
25
-
26
24
  enum QuantizerType {
27
- QT_8bit, ///< 8 bits per component
28
- QT_4bit, ///< 4 bits per component
29
- QT_8bit_uniform, ///< same, shared range for all dimensions
25
+ QT_8bit, ///< 8 bits per component
26
+ QT_4bit, ///< 4 bits per component
27
+ QT_8bit_uniform, ///< same, shared range for all dimensions
30
28
  QT_4bit_uniform,
31
29
  QT_fp16,
32
- QT_8bit_direct, ///< fast indexing of uint8s
33
- QT_6bit, ///< 6 bits per component
30
+ QT_8bit_direct, ///< fast indexing of uint8s
31
+ QT_6bit, ///< 6 bits per component
34
32
  };
35
33
 
36
34
  QuantizerType qtype;
@@ -41,10 +39,10 @@ struct ScalarQuantizer {
41
39
 
42
40
  // rangestat_arg.
43
41
  enum RangeStat {
44
- RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)]
45
- RS_meanstd, ///< [mean - std * rs, mean + std * rs]
46
- RS_quantiles, ///< [Q(rs), Q(1-rs)]
47
- RS_optim, ///< alternate optimization of reconstruction error
42
+ RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)]
43
+ RS_meanstd, ///< [mean - std * rs, mean + std * rs]
44
+ RS_quantiles, ///< [Q(rs), Q(1-rs)]
45
+ RS_optim, ///< alternate optimization of reconstruction error
48
46
  };
49
47
 
50
48
  RangeStat rangestat;
@@ -62,29 +60,35 @@ struct ScalarQuantizer {
62
60
  /// trained values (including the range)
63
61
  std::vector<float> trained;
64
62
 
65
- ScalarQuantizer (size_t d, QuantizerType qtype);
66
- ScalarQuantizer ();
63
+ ScalarQuantizer(size_t d, QuantizerType qtype);
64
+ ScalarQuantizer();
67
65
 
68
66
  /// updates internal values based on qtype and d
69
- void set_derived_sizes ();
67
+ void set_derived_sizes();
70
68
 
71
- void train (size_t n, const float *x);
69
+ void train(size_t n, const float* x);
72
70
 
73
71
  /// Used by an IVF index to train based on the residuals
74
- void train_residual (size_t n,
75
- const float *x,
76
- Index *quantizer,
77
- bool by_residual,
78
- bool verbose);
79
-
80
- /// same as compute_code for several vectors
81
- void compute_codes (const float * x,
82
- uint8_t * codes,
83
- size_t n) const ;
84
-
85
- /// decode a vector from a given code (or n vectors if third argument)
86
- void decode (const uint8_t *code, float *x, size_t n) const;
87
-
72
+ void train_residual(
73
+ size_t n,
74
+ const float* x,
75
+ Index* quantizer,
76
+ bool by_residual,
77
+ bool verbose);
78
+
79
+ /** Encode a set of vectors
80
+ *
81
+ * @param x vectors to encode, size n * d
82
+ * @param codes output codes, size n * code_size
83
+ */
84
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const;
85
+
86
+ /** Decode a set of vectors
87
+ *
88
+ * @param codes codes to decode, size n * code_size
89
+ * @param x output vectors, size n * d
90
+ */
91
+ void decode(const uint8_t* code, float* x, size_t n) const;
88
92
 
89
93
  /*****************************************************
90
94
  * Objects that provide methods for encoding/decoding, distance
@@ -93,34 +97,32 @@ struct ScalarQuantizer {
93
97
 
94
98
  struct Quantizer {
95
99
  // encodes one vector. Assumes code is filled with 0s on input!
96
- virtual void encode_vector(const float *x, uint8_t *code) const = 0;
97
- virtual void decode_vector(const uint8_t *code, float *x) const = 0;
100
+ virtual void encode_vector(const float* x, uint8_t* code) const = 0;
101
+ virtual void decode_vector(const uint8_t* code, float* x) const = 0;
98
102
 
99
103
  virtual ~Quantizer() {}
100
104
  };
101
105
 
102
- Quantizer * select_quantizer() const;
106
+ Quantizer* select_quantizer() const;
103
107
 
104
- struct SQDistanceComputer: DistanceComputer {
105
-
106
- const float *q;
107
- const uint8_t *codes;
108
+ struct SQDistanceComputer : DistanceComputer {
109
+ const float* q;
110
+ const uint8_t* codes;
108
111
  size_t code_size;
109
112
 
110
- SQDistanceComputer (): q(nullptr), codes (nullptr), code_size (0)
111
- {}
113
+ SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
112
114
 
115
+ virtual float query_to_code(const uint8_t* code) const = 0;
113
116
  };
114
117
 
115
- SQDistanceComputer *get_distance_computer (MetricType metric = METRIC_L2)
116
- const;
117
-
118
- InvertedListScanner *select_InvertedListScanner
119
- (MetricType mt, const Index *quantizer, bool store_pairs,
120
- bool by_residual=false) const;
118
+ SQDistanceComputer* get_distance_computer(
119
+ MetricType metric = METRIC_L2) const;
121
120
 
121
+ InvertedListScanner* select_InvertedListScanner(
122
+ MetricType mt,
123
+ const Index* quantizer,
124
+ bool store_pairs,
125
+ bool by_residual = false) const;
122
126
  };
123
127
 
124
-
125
-
126
128
  } // namespace faiss
@@ -13,180 +13,178 @@ namespace faiss {
13
13
 
14
14
  template <typename IndexT>
15
15
  ThreadedIndex<IndexT>::ThreadedIndex(bool threaded)
16
- // 0 is default dimension
17
- : ThreadedIndex(0, threaded) {
18
- }
16
+ // 0 is default dimension
17
+ : ThreadedIndex(0, threaded) {}
19
18
 
20
19
  template <typename IndexT>
21
20
  ThreadedIndex<IndexT>::ThreadedIndex(int d, bool threaded)
22
- : IndexT(d),
23
- own_fields(false),
24
- isThreaded_(threaded) {
25
- }
21
+ : IndexT(d), own_fields(false), isThreaded_(threaded) {}
26
22
 
27
23
  template <typename IndexT>
28
24
  ThreadedIndex<IndexT>::~ThreadedIndex() {
29
- for (auto& p : indices_) {
30
- if (isThreaded_) {
31
- // should have worker thread
32
- FAISS_ASSERT((bool) p.second);
33
-
34
- // This will also flush all pending work
35
- p.second->stop();
36
- p.second->waitForThreadExit();
37
- } else {
38
- // should not have worker thread
39
- FAISS_ASSERT(!(bool) p.second);
40
- }
41
-
42
- if (own_fields) {
43
- delete p.first;
25
+ for (auto& p : indices_) {
26
+ if (isThreaded_) {
27
+ // should have worker thread
28
+ FAISS_ASSERT((bool)p.second);
29
+
30
+ // This will also flush all pending work
31
+ p.second->stop();
32
+ p.second->waitForThreadExit();
33
+ } else {
34
+ // should not have worker thread
35
+ FAISS_ASSERT(!(bool)p.second);
36
+ }
37
+
38
+ if (own_fields) {
39
+ delete p.first;
40
+ }
44
41
  }
45
- }
46
42
  }
47
43
 
48
44
  template <typename IndexT>
49
45
  void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
50
- // We inherit the dimension from the first index added to us if we don't have
51
- // a set dimension
52
- if (indices_.empty() && this->d == 0) {
53
- this->d = index->d;
54
- }
55
-
56
- // The new index must match our set dimension
57
- FAISS_THROW_IF_NOT_FMT(this->d == index->d,
58
- "addIndex: dimension mismatch for "
59
- "newly added index; expecting dim %d, "
60
- "new index has dim %d",
61
- this->d, index->d);
62
-
63
- if (!indices_.empty()) {
64
- auto& existing = indices_.front().first;
65
-
66
- FAISS_THROW_IF_NOT_MSG(index->metric_type == existing->metric_type,
67
- "addIndex: newly added index is "
68
- "of different metric type than old index");
69
-
70
- // Make sure this index is not duplicated
71
- for (auto& p : indices_) {
72
- FAISS_THROW_IF_NOT_MSG(p.first != index,
73
- "addIndex: attempting to add index "
74
- "that is already in the collection");
46
+ // We inherit the dimension from the first index added to us if we don't
47
+ // have a set dimension
48
+ if (indices_.empty() && this->d == 0) {
49
+ this->d = index->d;
50
+ }
51
+
52
+ // The new index must match our set dimension
53
+ FAISS_THROW_IF_NOT_FMT(
54
+ this->d == index->d,
55
+ "addIndex: dimension mismatch for "
56
+ "newly added index; expecting dim %d, "
57
+ "new index has dim %d",
58
+ this->d,
59
+ index->d);
60
+
61
+ if (!indices_.empty()) {
62
+ auto& existing = indices_.front().first;
63
+
64
+ FAISS_THROW_IF_NOT_MSG(
65
+ index->metric_type == existing->metric_type,
66
+ "addIndex: newly added index is "
67
+ "of different metric type than old index");
68
+
69
+ // Make sure this index is not duplicated
70
+ for (auto& p : indices_) {
71
+ FAISS_THROW_IF_NOT_MSG(
72
+ p.first != index,
73
+ "addIndex: attempting to add index "
74
+ "that is already in the collection");
75
+ }
75
76
  }
76
- }
77
77
 
78
- indices_.emplace_back(
79
- std::make_pair(
80
- index,
81
- std::unique_ptr<WorkerThread>(isThreaded_ ?
82
- new WorkerThread : nullptr)));
78
+ indices_.emplace_back(std::make_pair(
79
+ index,
80
+ std::unique_ptr<WorkerThread>(
81
+ isThreaded_ ? new WorkerThread : nullptr)));
83
82
 
84
- onAfterAddIndex(index);
83
+ onAfterAddIndex(index);
85
84
  }
86
85
 
87
86
  template <typename IndexT>
88
87
  void ThreadedIndex<IndexT>::removeIndex(IndexT* index) {
89
- for (auto it = indices_.begin(); it != indices_.end(); ++it) {
90
- if (it->first == index) {
91
- // This is our index; stop the worker thread before removing it,
92
- // to ensure that it has finished before function exit
93
- if (isThreaded_) {
94
- // should have worker thread
95
- FAISS_ASSERT((bool) it->second);
96
- it->second->stop();
97
- it->second->waitForThreadExit();
98
- } else {
99
- // should not have worker thread
100
- FAISS_ASSERT(!(bool) it->second);
101
- }
102
-
103
- indices_.erase(it);
104
- onAfterRemoveIndex(index);
105
-
106
- if (own_fields) {
107
- delete index;
108
- }
109
-
110
- return;
88
+ for (auto it = indices_.begin(); it != indices_.end(); ++it) {
89
+ if (it->first == index) {
90
+ // This is our index; stop the worker thread before removing it,
91
+ // to ensure that it has finished before function exit
92
+ if (isThreaded_) {
93
+ // should have worker thread
94
+ FAISS_ASSERT((bool)it->second);
95
+ it->second->stop();
96
+ it->second->waitForThreadExit();
97
+ } else {
98
+ // should not have worker thread
99
+ FAISS_ASSERT(!(bool)it->second);
100
+ }
101
+
102
+ indices_.erase(it);
103
+ onAfterRemoveIndex(index);
104
+
105
+ if (own_fields) {
106
+ delete index;
107
+ }
108
+
109
+ return;
110
+ }
111
111
  }
112
- }
113
112
 
114
- // could not find our index
115
- FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
113
+ // could not find our index
114
+ FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found");
116
115
  }
117
116
 
118
117
  template <typename IndexT>
119
118
  void ThreadedIndex<IndexT>::runOnIndex(std::function<void(int, IndexT*)> f) {
120
- if (isThreaded_) {
121
- std::vector<std::future<bool>> v;
122
-
123
- for (int i = 0; i < this->indices_.size(); ++i) {
124
- auto& p = this->indices_[i];
125
- auto indexPtr = p.first;
126
- v.emplace_back(p.second->add([f, i, indexPtr](){ f(i, indexPtr); }));
127
- }
119
+ if (isThreaded_) {
120
+ std::vector<std::future<bool>> v;
128
121
 
129
- waitAndHandleFutures(v);
130
- } else {
131
- // Multiple exceptions may be thrown; gather them as we encounter them,
132
- // while letting everything else run to completion
133
- std::vector<std::pair<int, std::exception_ptr>> exceptions;
122
+ for (int i = 0; i < this->indices_.size(); ++i) {
123
+ auto& p = this->indices_[i];
124
+ auto indexPtr = p.first;
125
+ v.emplace_back(
126
+ p.second->add([f, i, indexPtr]() { f(i, indexPtr); }));
127
+ }
134
128
 
135
- for (int i = 0; i < this->indices_.size(); ++i) {
136
- auto& p = this->indices_[i];
137
- try {
138
- f(i, p.first);
139
- } catch (...) {
140
- exceptions.emplace_back(std::make_pair(i, std::current_exception()));
141
- }
129
+ waitAndHandleFutures(v);
130
+ } else {
131
+ // Multiple exceptions may be thrown; gather them as we encounter them,
132
+ // while letting everything else run to completion
133
+ std::vector<std::pair<int, std::exception_ptr>> exceptions;
134
+
135
+ for (int i = 0; i < this->indices_.size(); ++i) {
136
+ auto& p = this->indices_[i];
137
+ try {
138
+ f(i, p.first);
139
+ } catch (...) {
140
+ exceptions.emplace_back(
141
+ std::make_pair(i, std::current_exception()));
142
+ }
143
+ }
144
+
145
+ handleExceptions(exceptions);
142
146
  }
143
-
144
- handleExceptions(exceptions);
145
- }
146
147
  }
147
148
 
148
149
  template <typename IndexT>
149
150
  void ThreadedIndex<IndexT>::runOnIndex(
150
- std::function<void(int, const IndexT*)> f) const {
151
- const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
152
- [f](int i, IndexT* idx){ f(i, idx); });
151
+ std::function<void(int, const IndexT*)> f) const {
152
+ const_cast<ThreadedIndex<IndexT>*>(this)->runOnIndex(
153
+ [f](int i, IndexT* idx) { f(i, idx); });
153
154
  }
154
155
 
155
156
  template <typename IndexT>
156
157
  void ThreadedIndex<IndexT>::reset() {
157
- runOnIndex([](int, IndexT* index){ index->reset(); });
158
- this->ntotal = 0;
159
- this->is_trained = false;
158
+ runOnIndex([](int, IndexT* index) { index->reset(); });
159
+ this->ntotal = 0;
160
+ this->is_trained = false;
160
161
  }
161
162
 
162
163
  template <typename IndexT>
163
- void
164
- ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {
165
- }
164
+ void ThreadedIndex<IndexT>::onAfterAddIndex(IndexT* index) {}
166
165
 
167
166
  template <typename IndexT>
168
- void
169
- ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {
170
- }
167
+ void ThreadedIndex<IndexT>::onAfterRemoveIndex(IndexT* index) {}
171
168
 
172
169
  template <typename IndexT>
173
- void
174
- ThreadedIndex<IndexT>::waitAndHandleFutures(std::vector<std::future<bool>>& v) {
175
- // Blocking wait for completion for all of the indices, capturing any
176
- // exceptions that are generated
177
- std::vector<std::pair<int, std::exception_ptr>> exceptions;
178
-
179
- for (int i = 0; i < v.size(); ++i) {
180
- auto& fut = v[i];
181
-
182
- try {
183
- fut.get();
184
- } catch (...) {
185
- exceptions.emplace_back(std::make_pair(i, std::current_exception()));
170
+ void ThreadedIndex<IndexT>::waitAndHandleFutures(
171
+ std::vector<std::future<bool>>& v) {
172
+ // Blocking wait for completion for all of the indices, capturing any
173
+ // exceptions that are generated
174
+ std::vector<std::pair<int, std::exception_ptr>> exceptions;
175
+
176
+ for (int i = 0; i < v.size(); ++i) {
177
+ auto& fut = v[i];
178
+
179
+ try {
180
+ fut.get();
181
+ } catch (...) {
182
+ exceptions.emplace_back(
183
+ std::make_pair(i, std::current_exception()));
184
+ }
186
185
  }
187
- }
188
186
 
189
- handleExceptions(exceptions);
187
+ handleExceptions(exceptions);
190
188
  }
191
189
 
192
- } // namespace
190
+ } // namespace faiss
@@ -19,62 +19,68 @@ namespace faiss {
19
19
  /// The interface to this class itself is not thread safe
20
20
  template <typename IndexT>
21
21
  class ThreadedIndex : public IndexT {
22
- public:
23
- explicit ThreadedIndex(bool threaded);
24
- explicit ThreadedIndex(int d, bool threaded);
25
-
26
- ~ThreadedIndex() override;
27
-
28
- /// override an index that is managed by ourselves.
29
- /// WARNING: once an index is added, it becomes unsafe to touch it from any
30
- /// other thread than that on which is managing it, until we are shut
31
- /// down. Use runOnIndex to perform work on it instead.
32
- void addIndex(IndexT* index);
33
-
34
- /// Remove an index that is managed by ourselves.
35
- /// This will flush all pending work on that index, and then shut
36
- /// down its managing thread, and will remove the index.
37
- void removeIndex(IndexT* index);
38
-
39
- /// Run a function on all indices, in the thread that the index is
40
- /// managed in.
41
- /// Function arguments are (index in collection, index pointer)
42
- void runOnIndex(std::function<void(int, IndexT*)> f);
43
- void runOnIndex(std::function<void(int, const IndexT*)> f) const;
44
-
45
- /// faiss::Index API
46
- /// All indices receive the same call
47
- void reset() override;
48
-
49
- /// Returns the number of sub-indices
50
- int count() const { return indices_.size(); }
51
-
52
- /// Returns the i-th sub-index
53
- IndexT* at(int i) { return indices_[i].first; }
54
-
55
- /// Returns the i-th sub-index (const version)
56
- const IndexT* at(int i) const { return indices_[i].first; }
57
-
58
- /// Whether or not we are responsible for deleting our contained indices
59
- bool own_fields;
60
-
61
- protected:
62
- /// Called just after an index is added
63
- virtual void onAfterAddIndex(IndexT* index);
64
-
65
- /// Called just after an index is removed
66
- virtual void onAfterRemoveIndex(IndexT* index);
67
-
68
- protected:
69
- static void waitAndHandleFutures(std::vector<std::future<bool>>& v);
70
-
71
- /// Collection of Index instances, with their managing worker thread if any
72
- std::vector<std::pair<IndexT*, std::unique_ptr<WorkerThread>>> indices_;
73
-
74
- /// Is this index multi-threaded?
75
- bool isThreaded_;
22
+ public:
23
+ explicit ThreadedIndex(bool threaded);
24
+ explicit ThreadedIndex(int d, bool threaded);
25
+
26
+ ~ThreadedIndex() override;
27
+
28
+ /// override an index that is managed by ourselves.
29
+ /// WARNING: once an index is added, it becomes unsafe to touch it from any
30
+ /// other thread than that on which is managing it, until we are shut
31
+ /// down. Use runOnIndex to perform work on it instead.
32
+ void addIndex(IndexT* index);
33
+
34
+ /// Remove an index that is managed by ourselves.
35
+ /// This will flush all pending work on that index, and then shut
36
+ /// down its managing thread, and will remove the index.
37
+ void removeIndex(IndexT* index);
38
+
39
+ /// Run a function on all indices, in the thread that the index is
40
+ /// managed in.
41
+ /// Function arguments are (index in collection, index pointer)
42
+ void runOnIndex(std::function<void(int, IndexT*)> f);
43
+ void runOnIndex(std::function<void(int, const IndexT*)> f) const;
44
+
45
+ /// faiss::Index API
46
+ /// All indices receive the same call
47
+ void reset() override;
48
+
49
+ /// Returns the number of sub-indices
50
+ int count() const {
51
+ return indices_.size();
52
+ }
53
+
54
+ /// Returns the i-th sub-index
55
+ IndexT* at(int i) {
56
+ return indices_[i].first;
57
+ }
58
+
59
+ /// Returns the i-th sub-index (const version)
60
+ const IndexT* at(int i) const {
61
+ return indices_[i].first;
62
+ }
63
+
64
+ /// Whether or not we are responsible for deleting our contained indices
65
+ bool own_fields;
66
+
67
+ protected:
68
+ /// Called just after an index is added
69
+ virtual void onAfterAddIndex(IndexT* index);
70
+
71
+ /// Called just after an index is removed
72
+ virtual void onAfterRemoveIndex(IndexT* index);
73
+
74
+ protected:
75
+ static void waitAndHandleFutures(std::vector<std::future<bool>>& v);
76
+
77
+ /// Collection of Index instances, with their managing worker thread if any
78
+ std::vector<std::pair<IndexT*, std::unique_ptr<WorkerThread>>> indices_;
79
+
80
+ /// Is this index multi-threaded?
81
+ bool isThreaded_;
76
82
  };
77
83
 
78
- } // namespace
84
+ } // namespace faiss
79
85
 
80
86
  #include <faiss/impl/ThreadedIndex-inl.h>