faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -16,7 +16,6 @@
16
16
  #include <faiss/IndexIVF.h>
17
17
  #include <faiss/impl/ScalarQuantizer.h>
18
18
 
19
-
20
19
  namespace faiss {
21
20
 
22
21
  /**
@@ -25,10 +24,7 @@ namespace faiss {
25
24
  * (default).
26
25
  */
27
26
 
28
-
29
-
30
-
31
- struct IndexScalarQuantizer: Index {
27
+ struct IndexScalarQuantizer : Index {
32
28
  /// Used to encode the vectors
33
29
  ScalarQuantizer sq;
34
30
 
@@ -43,22 +39,23 @@ struct IndexScalarQuantizer: Index {
43
39
  * @param M number of subquantizers
44
40
  * @param nbits number of bit per subvector index
45
41
  */
46
- IndexScalarQuantizer (int d,
47
- ScalarQuantizer::QuantizerType qtype,
48
- MetricType metric = METRIC_L2);
42
+ IndexScalarQuantizer(
43
+ int d,
44
+ ScalarQuantizer::QuantizerType qtype,
45
+ MetricType metric = METRIC_L2);
49
46
 
50
- IndexScalarQuantizer ();
47
+ IndexScalarQuantizer();
51
48
 
52
49
  void train(idx_t n, const float* x) override;
53
50
 
54
51
  void add(idx_t n, const float* x) override;
55
52
 
56
53
  void search(
57
- idx_t n,
58
- const float* x,
59
- idx_t k,
60
- float* distances,
61
- idx_t* labels) const override;
54
+ idx_t n,
55
+ const float* x,
56
+ idx_t k,
57
+ float* distances,
58
+ idx_t* labels) const override;
62
59
 
63
60
  void reset() override;
64
61
 
@@ -66,62 +63,61 @@ struct IndexScalarQuantizer: Index {
66
63
 
67
64
  void reconstruct(idx_t key, float* recons) const override;
68
65
 
69
- DistanceComputer *get_distance_computer () const override;
66
+ DistanceComputer* get_distance_computer() const override;
70
67
 
71
68
  /* standalone codec interface */
72
- size_t sa_code_size () const override;
73
-
74
- void sa_encode (idx_t n, const float *x,
75
- uint8_t *bytes) const override;
76
-
77
- void sa_decode (idx_t n, const uint8_t *bytes,
78
- float *x) const override;
69
+ size_t sa_code_size() const override;
79
70
 
71
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
80
72
 
73
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
81
74
  };
82
75
 
83
-
84
- /** An IVF implementation where the components of the residuals are
76
+ /** An IVF implementation where the components of the residuals are
85
77
  * encoded with a scalar quantizer. All distance computations
86
78
  * are asymmetric, so the encoded vectors are decoded and approximate
87
79
  * distances are computed.
88
80
  */
89
81
 
90
- struct IndexIVFScalarQuantizer: IndexIVF {
82
+ struct IndexIVFScalarQuantizer : IndexIVF {
91
83
  ScalarQuantizer sq;
92
84
  bool by_residual;
93
85
 
94
- IndexIVFScalarQuantizer(Index *quantizer, size_t d, size_t nlist,
95
- ScalarQuantizer::QuantizerType qtype,
96
- MetricType metric = METRIC_L2,
97
- bool encode_residual = true);
86
+ IndexIVFScalarQuantizer(
87
+ Index* quantizer,
88
+ size_t d,
89
+ size_t nlist,
90
+ ScalarQuantizer::QuantizerType qtype,
91
+ MetricType metric = METRIC_L2,
92
+ bool encode_residual = true);
98
93
 
99
94
  IndexIVFScalarQuantizer();
100
95
 
101
96
  void train_residual(idx_t n, const float* x) override;
102
97
 
103
- void encode_vectors(idx_t n, const float* x,
104
- const idx_t *list_nos,
105
- uint8_t * codes,
106
- bool include_listnos=false) const override;
107
-
108
- void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
98
+ void encode_vectors(
99
+ idx_t n,
100
+ const float* x,
101
+ const idx_t* list_nos,
102
+ uint8_t* codes,
103
+ bool include_listnos = false) const override;
109
104
 
110
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
111
- const override;
105
+ void add_core(
106
+ idx_t n,
107
+ const float* x,
108
+ const idx_t* xids,
109
+ const idx_t* precomputed_idx) override;
112
110
 
111
+ InvertedListScanner* get_InvertedListScanner(
112
+ bool store_pairs) const override;
113
113
 
114
- void reconstruct_from_offset (int64_t list_no, int64_t offset,
115
- float* recons) const override;
114
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
115
+ const override;
116
116
 
117
117
  /* standalone codec interface */
118
- void sa_decode (idx_t n, const uint8_t *bytes,
119
- float *x) const override;
120
-
118
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
121
119
  };
122
120
 
123
-
124
- }
125
-
121
+ } // namespace faiss
126
122
 
127
123
  #endif
@@ -24,18 +24,17 @@ namespace {
24
24
 
25
25
  typedef Index::idx_t idx_t;
26
26
 
27
-
28
27
  // add translation to all valid labels
29
- void translate_labels (long n, idx_t *labels, long translation)
30
- {
31
- if (translation == 0) return;
28
+ void translate_labels(long n, idx_t* labels, long translation) {
29
+ if (translation == 0)
30
+ return;
32
31
  for (long i = 0; i < n; i++) {
33
- if(labels[i] < 0) continue;
32
+ if (labels[i] < 0)
33
+ continue;
34
34
  labels[i] += translation;
35
35
  }
36
36
  }
37
37
 
38
-
39
38
  /** merge result tables from several shards.
40
39
  * @param all_distances size nshard * n * k
41
40
  * @param all_labels idem
@@ -43,296 +42,313 @@ void translate_labels (long n, idx_t *labels, long translation)
43
42
  */
44
43
 
45
44
  template <class IndexClass, class C>
46
- void
47
- merge_tables(long n, long k, long nshard,
48
- typename IndexClass::distance_t *distances,
49
- idx_t *labels,
50
- const std::vector<typename IndexClass::distance_t>& all_distances,
51
- const std::vector<idx_t>& all_labels,
52
- const std::vector<long>& translations) {
53
- if (k == 0) {
54
- return;
55
- }
56
- using distance_t = typename IndexClass::distance_t;
57
-
58
- long stride = n * k;
45
+ void merge_tables(
46
+ long n,
47
+ long k,
48
+ long nshard,
49
+ typename IndexClass::distance_t* distances,
50
+ idx_t* labels,
51
+ const std::vector<typename IndexClass::distance_t>& all_distances,
52
+ const std::vector<idx_t>& all_labels,
53
+ const std::vector<long>& translations) {
54
+ if (k == 0) {
55
+ return;
56
+ }
57
+ using distance_t = typename IndexClass::distance_t;
58
+
59
+ long stride = n * k;
59
60
  #pragma omp parallel
60
- {
61
- std::vector<int> buf (2 * nshard);
62
- int * pointer = buf.data();
63
- int * shard_ids = pointer + nshard;
64
- std::vector<distance_t> buf2 (nshard);
65
- distance_t * heap_vals = buf2.data();
61
+ {
62
+ std::vector<int> buf(2 * nshard);
63
+ int* pointer = buf.data();
64
+ int* shard_ids = pointer + nshard;
65
+ std::vector<distance_t> buf2(nshard);
66
+ distance_t* heap_vals = buf2.data();
66
67
  #pragma omp for
67
- for (long i = 0; i < n; i++) {
68
- // the heap maps values to the shard where they are
69
- // produced.
70
- const distance_t *D_in = all_distances.data() + i * k;
71
- const idx_t *I_in = all_labels.data() + i * k;
72
- int heap_size = 0;
73
-
74
- for (long s = 0; s < nshard; s++) {
75
- pointer[s] = 0;
76
- if (I_in[stride * s] >= 0) {
77
- heap_push<C> (++heap_size, heap_vals, shard_ids,
78
- D_in[stride * s], s);
79
- }
80
- }
81
-
82
- distance_t *D = distances + i * k;
83
- idx_t *I = labels + i * k;
84
-
85
- for (int j = 0; j < k; j++) {
86
- if (heap_size == 0) {
87
- I[j] = -1;
88
- D[j] = C::neutral();
89
- } else {
90
- // pop best element
91
- int s = shard_ids[0];
92
- int & p = pointer[s];
93
- D[j] = heap_vals[0];
94
- I[j] = I_in[stride * s + p] + translations[s];
95
-
96
- heap_pop<C> (heap_size--, heap_vals, shard_ids);
97
- p++;
98
- if (p < k && I_in[stride * s + p] >= 0) {
99
- heap_push<C> (++heap_size, heap_vals, shard_ids,
100
- D_in[stride * s + p], s);
101
- }
68
+ for (long i = 0; i < n; i++) {
69
+ // the heap maps values to the shard where they are
70
+ // produced.
71
+ const distance_t* D_in = all_distances.data() + i * k;
72
+ const idx_t* I_in = all_labels.data() + i * k;
73
+ int heap_size = 0;
74
+
75
+ for (long s = 0; s < nshard; s++) {
76
+ pointer[s] = 0;
77
+ if (I_in[stride * s] >= 0) {
78
+ heap_push<C>(
79
+ ++heap_size,
80
+ heap_vals,
81
+ shard_ids,
82
+ D_in[stride * s],
83
+ s);
84
+ }
85
+ }
86
+
87
+ distance_t* D = distances + i * k;
88
+ idx_t* I = labels + i * k;
89
+
90
+ for (int j = 0; j < k; j++) {
91
+ if (heap_size == 0) {
92
+ I[j] = -1;
93
+ D[j] = C::neutral();
94
+ } else {
95
+ // pop best element
96
+ int s = shard_ids[0];
97
+ int& p = pointer[s];
98
+ D[j] = heap_vals[0];
99
+ I[j] = I_in[stride * s + p] + translations[s];
100
+
101
+ heap_pop<C>(heap_size--, heap_vals, shard_ids);
102
+ p++;
103
+ if (p < k && I_in[stride * s + p] >= 0) {
104
+ heap_push<C>(
105
+ ++heap_size,
106
+ heap_vals,
107
+ shard_ids,
108
+ D_in[stride * s + p],
109
+ s);
110
+ }
111
+ }
112
+ }
102
113
  }
103
- }
104
114
  }
105
- }
106
115
  }
107
116
 
108
117
  } // anonymous namespace
109
118
 
110
119
  template <typename IndexT>
111
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(idx_t d,
112
- bool threaded,
113
- bool successive_ids)
114
- : ThreadedIndex<IndexT>(d, threaded),
115
- successive_ids(successive_ids) {
116
- }
120
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
121
+ idx_t d,
122
+ bool threaded,
123
+ bool successive_ids)
124
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
117
125
 
118
126
  template <typename IndexT>
119
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(int d,
120
- bool threaded,
121
- bool successive_ids)
122
- : ThreadedIndex<IndexT>(d, threaded),
123
- successive_ids(successive_ids) {
124
- }
127
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
128
+ int d,
129
+ bool threaded,
130
+ bool successive_ids)
131
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
125
132
 
126
133
  template <typename IndexT>
127
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(bool threaded,
128
- bool successive_ids)
129
- : ThreadedIndex<IndexT>(threaded),
130
- successive_ids(successive_ids) {
131
- }
134
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
135
+ bool threaded,
136
+ bool successive_ids)
137
+ : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
132
138
 
133
139
  template <typename IndexT>
134
- void
135
- IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
136
- syncWithSubIndexes();
140
+ void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
141
+ syncWithSubIndexes();
137
142
  }
138
143
 
139
144
  template <typename IndexT>
140
- void
141
- IndexShardsTemplate<IndexT>::onAfterRemoveIndex(IndexT* index /* unused */) {
142
- syncWithSubIndexes();
145
+ void IndexShardsTemplate<IndexT>::onAfterRemoveIndex(
146
+ IndexT* index /* unused */) {
147
+ syncWithSubIndexes();
143
148
  }
144
149
 
145
150
  // FIXME: assumes that nothing is currently running on the sub-indexes, which is
146
151
  // true with the normal API, but should use the runOnIndex API instead
147
152
  template <typename IndexT>
148
- void
149
- IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
150
- if (!this->count()) {
151
- this->is_trained = false;
152
- this->ntotal = 0;
153
-
154
- return;
155
- }
156
-
157
- auto firstIndex = this->at(0);
158
- this->metric_type = firstIndex->metric_type;
159
- this->is_trained = firstIndex->is_trained;
160
- this->ntotal = firstIndex->ntotal;
161
-
162
- for (int i = 1; i < this->count(); ++i) {
163
- auto index = this->at(i);
164
- FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
165
- FAISS_THROW_IF_NOT(this->d == index->d);
166
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
167
-
168
- this->ntotal += index->ntotal;
169
- }
153
+ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
154
+ if (!this->count()) {
155
+ this->is_trained = false;
156
+ this->ntotal = 0;
157
+
158
+ return;
159
+ }
160
+
161
+ auto firstIndex = this->at(0);
162
+ this->metric_type = firstIndex->metric_type;
163
+ this->is_trained = firstIndex->is_trained;
164
+ this->ntotal = firstIndex->ntotal;
165
+
166
+ for (int i = 1; i < this->count(); ++i) {
167
+ auto index = this->at(i);
168
+ FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
169
+ FAISS_THROW_IF_NOT(this->d == index->d);
170
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
171
+
172
+ this->ntotal += index->ntotal;
173
+ }
170
174
  }
171
175
 
172
176
  // No metric_type for IndexBinary
173
177
  template <>
174
- void
175
- IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
176
- if (!this->count()) {
177
- this->is_trained = false;
178
- this->ntotal = 0;
179
-
180
- return;
181
- }
182
-
183
- auto firstIndex = this->at(0);
184
- this->is_trained = firstIndex->is_trained;
185
- this->ntotal = firstIndex->ntotal;
186
-
187
- for (int i = 1; i < this->count(); ++i) {
188
- auto index = this->at(i);
189
- FAISS_THROW_IF_NOT(this->d == index->d);
190
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
191
-
192
- this->ntotal += index->ntotal;
193
- }
178
+ void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
179
+ if (!this->count()) {
180
+ this->is_trained = false;
181
+ this->ntotal = 0;
182
+
183
+ return;
184
+ }
185
+
186
+ auto firstIndex = this->at(0);
187
+ this->is_trained = firstIndex->is_trained;
188
+ this->ntotal = firstIndex->ntotal;
189
+
190
+ for (int i = 1; i < this->count(); ++i) {
191
+ auto index = this->at(i);
192
+ FAISS_THROW_IF_NOT(this->d == index->d);
193
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
194
+
195
+ this->ntotal += index->ntotal;
196
+ }
194
197
  }
195
198
 
196
199
  template <typename IndexT>
197
- void
198
- IndexShardsTemplate<IndexT>::train(idx_t n,
199
- const component_t *x) {
200
- auto fn =
201
- [n, x](int no, IndexT *index) {
202
- if (index->verbose) {
203
- printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
- }
205
-
206
- index->train(n, x);
207
-
208
- if (index->verbose) {
209
- printf("end train shard %d\n", no);
210
- }
200
+ void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
201
+ auto fn = [n, x](int no, IndexT* index) {
202
+ if (index->verbose) {
203
+ printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
+ }
205
+
206
+ index->train(n, x);
207
+
208
+ if (index->verbose) {
209
+ printf("end train shard %d\n", no);
210
+ }
211
211
  };
212
212
 
213
- this->runOnIndex(fn);
214
- syncWithSubIndexes();
213
+ this->runOnIndex(fn);
214
+ syncWithSubIndexes();
215
215
  }
216
216
 
217
217
  template <typename IndexT>
218
- void
219
- IndexShardsTemplate<IndexT>::add(idx_t n,
220
- const component_t *x) {
221
- add_with_ids(n, x, nullptr);
218
+ void IndexShardsTemplate<IndexT>::add(idx_t n, const component_t* x) {
219
+ add_with_ids(n, x, nullptr);
222
220
  }
223
221
 
224
222
  template <typename IndexT>
225
- void
226
- IndexShardsTemplate<IndexT>::add_with_ids(idx_t n,
227
- const component_t * x,
228
- const idx_t *xids) {
229
-
230
- FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
231
- "It makes no sense to pass in ids and "
232
- "request them to be shifted");
233
-
234
- if (successive_ids) {
235
- FAISS_THROW_IF_NOT_MSG(!xids,
236
- "It makes no sense to pass in ids and "
237
- "request them to be shifted");
238
- FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
239
- "when adding to IndexShards with sucessive_ids, "
240
- "only add() in a single pass is supported");
241
- }
242
-
243
- idx_t nshard = this->count();
244
- const idx_t *ids = xids;
245
-
246
- std::vector<idx_t> aids;
247
-
248
- if (!ids && !successive_ids) {
249
- aids.resize(n);
250
-
251
- for (idx_t i = 0; i < n; i++) {
252
- aids[i] = this->ntotal + i;
223
+ void IndexShardsTemplate<IndexT>::add_with_ids(
224
+ idx_t n,
225
+ const component_t* x,
226
+ const idx_t* xids) {
227
+ FAISS_THROW_IF_NOT_MSG(
228
+ !(successive_ids && xids),
229
+ "It makes no sense to pass in ids and "
230
+ "request them to be shifted");
231
+
232
+ if (successive_ids) {
233
+ FAISS_THROW_IF_NOT_MSG(
234
+ !xids,
235
+ "It makes no sense to pass in ids and "
236
+ "request them to be shifted");
237
+ FAISS_THROW_IF_NOT_MSG(
238
+ this->ntotal == 0,
239
+ "when adding to IndexShards with sucessive_ids, "
240
+ "only add() in a single pass is supported");
253
241
  }
254
242
 
255
- ids = aids.data();
256
- }
243
+ idx_t nshard = this->count();
244
+ const idx_t* ids = xids;
245
+
246
+ std::vector<idx_t> aids;
247
+
248
+ if (!ids && !successive_ids) {
249
+ aids.resize(n);
250
+
251
+ for (idx_t i = 0; i < n; i++) {
252
+ aids[i] = this->ntotal + i;
253
+ }
254
+
255
+ ids = aids.data();
256
+ }
257
257
 
258
- size_t components_per_vec =
259
- sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
258
+ size_t components_per_vec =
259
+ sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
260
260
 
261
- auto fn =
262
- [n, ids, x, nshard, components_per_vec](int no, IndexT *index) {
263
- idx_t i0 = (idx_t) no * n / nshard;
264
- idx_t i1 = ((idx_t) no + 1) * n / nshard;
265
- auto x0 = x + i0 * components_per_vec;
261
+ auto fn = [n, ids, x, nshard, components_per_vec](int no, IndexT* index) {
262
+ idx_t i0 = (idx_t)no * n / nshard;
263
+ idx_t i1 = ((idx_t)no + 1) * n / nshard;
264
+ auto x0 = x + i0 * components_per_vec;
266
265
 
267
- if (index->verbose) {
268
- printf ("begin add shard %d on %" PRId64 " points\n", no, n);
269
- }
266
+ if (index->verbose) {
267
+ printf("begin add shard %d on %" PRId64 " points\n", no, n);
268
+ }
270
269
 
271
- if (ids) {
272
- index->add_with_ids (i1 - i0, x0, ids + i0);
273
- } else {
274
- index->add (i1 - i0, x0);
275
- }
270
+ if (ids) {
271
+ index->add_with_ids(i1 - i0, x0, ids + i0);
272
+ } else {
273
+ index->add(i1 - i0, x0);
274
+ }
276
275
 
277
- if (index->verbose) {
278
- printf ("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
279
- }
276
+ if (index->verbose) {
277
+ printf("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
278
+ }
280
279
  };
281
280
 
282
- this->runOnIndex(fn);
283
- syncWithSubIndexes();
281
+ this->runOnIndex(fn);
282
+ syncWithSubIndexes();
284
283
  }
285
284
 
286
285
  template <typename IndexT>
287
- void
288
- IndexShardsTemplate<IndexT>::search(idx_t n,
289
- const component_t *x,
290
- idx_t k,
291
- distance_t *distances,
292
- idx_t *labels) const {
293
- long nshard = this->count();
294
-
295
- std::vector<distance_t> all_distances(nshard * k * n);
296
- std::vector<idx_t> all_labels(nshard * k * n);
297
-
298
- auto fn =
299
- [n, k, x, &all_distances, &all_labels](int no, const IndexT *index) {
300
- if (index->verbose) {
301
- printf ("begin query shard %d on %" PRId64 " points\n", no, n);
302
- }
303
-
304
- index->search (n, x, k,
305
- all_distances.data() + no * k * n,
306
- all_labels.data() + no * k * n);
307
-
308
- if (index->verbose) {
309
- printf ("end query shard %d\n", no);
310
- }
286
+ void IndexShardsTemplate<IndexT>::search(
287
+ idx_t n,
288
+ const component_t* x,
289
+ idx_t k,
290
+ distance_t* distances,
291
+ idx_t* labels) const {
292
+ FAISS_THROW_IF_NOT(k > 0);
293
+
294
+ long nshard = this->count();
295
+
296
+ std::vector<distance_t> all_distances(nshard * k * n);
297
+ std::vector<idx_t> all_labels(nshard * k * n);
298
+
299
+ auto fn = [n, k, x, &all_distances, &all_labels](
300
+ int no, const IndexT* index) {
301
+ if (index->verbose) {
302
+ printf("begin query shard %d on %" PRId64 " points\n", no, n);
303
+ }
304
+
305
+ index->search(
306
+ n,
307
+ x,
308
+ k,
309
+ all_distances.data() + no * k * n,
310
+ all_labels.data() + no * k * n);
311
+
312
+ if (index->verbose) {
313
+ printf("end query shard %d\n", no);
314
+ }
311
315
  };
312
316
 
313
- this->runOnIndex(fn);
317
+ this->runOnIndex(fn);
318
+
319
+ std::vector<long> translations(nshard, 0);
314
320
 
315
- std::vector<long> translations(nshard, 0);
321
+ // Because we just called runOnIndex above, it is safe to access the
322
+ // sub-index ntotal here
323
+ if (successive_ids) {
324
+ translations[0] = 0;
316
325
 
317
- // Because we just called runOnIndex above, it is safe to access the sub-index
318
- // ntotal here
319
- if (successive_ids) {
320
- translations[0] = 0;
326
+ for (int s = 0; s + 1 < nshard; s++) {
327
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
328
+ }
329
+ }
321
330
 
322
- for (int s = 0; s + 1 < nshard; s++) {
323
- translations[s + 1] = translations[s] + this->at(s)->ntotal;
331
+ if (this->metric_type == METRIC_L2) {
332
+ merge_tables<IndexT, CMin<distance_t, int>>(
333
+ n,
334
+ k,
335
+ nshard,
336
+ distances,
337
+ labels,
338
+ all_distances,
339
+ all_labels,
340
+ translations);
341
+ } else {
342
+ merge_tables<IndexT, CMax<distance_t, int>>(
343
+ n,
344
+ k,
345
+ nshard,
346
+ distances,
347
+ labels,
348
+ all_distances,
349
+ all_labels,
350
+ translations);
324
351
  }
325
- }
326
-
327
- if (this->metric_type == METRIC_L2) {
328
- merge_tables<IndexT, CMin<distance_t, int>>(
329
- n, k, nshard, distances, labels,
330
- all_distances, all_labels, translations);
331
- } else {
332
- merge_tables<IndexT, CMax<distance_t, int>>(
333
- n, k, nshard, distances, labels,
334
- all_distances, all_labels, translations);
335
- }
336
352
  }
337
353
 
338
354
  // explicit instanciations