faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -16,7 +16,6 @@
16
16
  #include <faiss/IndexIVF.h>
17
17
  #include <faiss/impl/ScalarQuantizer.h>
18
18
 
19
-
20
19
  namespace faiss {
21
20
 
22
21
  /**
@@ -25,10 +24,7 @@ namespace faiss {
25
24
  * (default).
26
25
  */
27
26
 
28
-
29
-
30
-
31
- struct IndexScalarQuantizer: Index {
27
+ struct IndexScalarQuantizer : Index {
32
28
  /// Used to encode the vectors
33
29
  ScalarQuantizer sq;
34
30
 
@@ -43,22 +39,23 @@ struct IndexScalarQuantizer: Index {
43
39
  * @param M number of subquantizers
44
40
  * @param nbits number of bit per subvector index
45
41
  */
46
- IndexScalarQuantizer (int d,
47
- ScalarQuantizer::QuantizerType qtype,
48
- MetricType metric = METRIC_L2);
42
+ IndexScalarQuantizer(
43
+ int d,
44
+ ScalarQuantizer::QuantizerType qtype,
45
+ MetricType metric = METRIC_L2);
49
46
 
50
- IndexScalarQuantizer ();
47
+ IndexScalarQuantizer();
51
48
 
52
49
  void train(idx_t n, const float* x) override;
53
50
 
54
51
  void add(idx_t n, const float* x) override;
55
52
 
56
53
  void search(
57
- idx_t n,
58
- const float* x,
59
- idx_t k,
60
- float* distances,
61
- idx_t* labels) const override;
54
+ idx_t n,
55
+ const float* x,
56
+ idx_t k,
57
+ float* distances,
58
+ idx_t* labels) const override;
62
59
 
63
60
  void reset() override;
64
61
 
@@ -66,62 +63,61 @@ struct IndexScalarQuantizer: Index {
66
63
 
67
64
  void reconstruct(idx_t key, float* recons) const override;
68
65
 
69
- DistanceComputer *get_distance_computer () const override;
66
+ DistanceComputer* get_distance_computer() const override;
70
67
 
71
68
  /* standalone codec interface */
72
- size_t sa_code_size () const override;
73
-
74
- void sa_encode (idx_t n, const float *x,
75
- uint8_t *bytes) const override;
76
-
77
- void sa_decode (idx_t n, const uint8_t *bytes,
78
- float *x) const override;
69
+ size_t sa_code_size() const override;
79
70
 
71
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
80
72
 
73
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
81
74
  };
82
75
 
83
-
84
- /** An IVF implementation where the components of the residuals are
76
+ /** An IVF implementation where the components of the residuals are
85
77
  * encoded with a scalar quantizer. All distance computations
86
78
  * are asymmetric, so the encoded vectors are decoded and approximate
87
79
  * distances are computed.
88
80
  */
89
81
 
90
- struct IndexIVFScalarQuantizer: IndexIVF {
82
+ struct IndexIVFScalarQuantizer : IndexIVF {
91
83
  ScalarQuantizer sq;
92
84
  bool by_residual;
93
85
 
94
- IndexIVFScalarQuantizer(Index *quantizer, size_t d, size_t nlist,
95
- ScalarQuantizer::QuantizerType qtype,
96
- MetricType metric = METRIC_L2,
97
- bool encode_residual = true);
86
+ IndexIVFScalarQuantizer(
87
+ Index* quantizer,
88
+ size_t d,
89
+ size_t nlist,
90
+ ScalarQuantizer::QuantizerType qtype,
91
+ MetricType metric = METRIC_L2,
92
+ bool encode_residual = true);
98
93
 
99
94
  IndexIVFScalarQuantizer();
100
95
 
101
96
  void train_residual(idx_t n, const float* x) override;
102
97
 
103
- void encode_vectors(idx_t n, const float* x,
104
- const idx_t *list_nos,
105
- uint8_t * codes,
106
- bool include_listnos=false) const override;
107
-
108
- void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
98
+ void encode_vectors(
99
+ idx_t n,
100
+ const float* x,
101
+ const idx_t* list_nos,
102
+ uint8_t* codes,
103
+ bool include_listnos = false) const override;
109
104
 
110
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
111
- const override;
105
+ void add_core(
106
+ idx_t n,
107
+ const float* x,
108
+ const idx_t* xids,
109
+ const idx_t* precomputed_idx) override;
112
110
 
111
+ InvertedListScanner* get_InvertedListScanner(
112
+ bool store_pairs) const override;
113
113
 
114
- void reconstruct_from_offset (int64_t list_no, int64_t offset,
115
- float* recons) const override;
114
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
115
+ const override;
116
116
 
117
117
  /* standalone codec interface */
118
- void sa_decode (idx_t n, const uint8_t *bytes,
119
- float *x) const override;
120
-
118
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
121
119
  };
122
120
 
123
-
124
- }
125
-
121
+ } // namespace faiss
126
122
 
127
123
  #endif
@@ -24,18 +24,17 @@ namespace {
24
24
 
25
25
  typedef Index::idx_t idx_t;
26
26
 
27
-
28
27
  // add translation to all valid labels
29
- void translate_labels (long n, idx_t *labels, long translation)
30
- {
31
- if (translation == 0) return;
28
+ void translate_labels(long n, idx_t* labels, long translation) {
29
+ if (translation == 0)
30
+ return;
32
31
  for (long i = 0; i < n; i++) {
33
- if(labels[i] < 0) continue;
32
+ if (labels[i] < 0)
33
+ continue;
34
34
  labels[i] += translation;
35
35
  }
36
36
  }
37
37
 
38
-
39
38
  /** merge result tables from several shards.
40
39
  * @param all_distances size nshard * n * k
41
40
  * @param all_labels idem
@@ -43,296 +42,313 @@ void translate_labels (long n, idx_t *labels, long translation)
43
42
  */
44
43
 
45
44
  template <class IndexClass, class C>
46
- void
47
- merge_tables(long n, long k, long nshard,
48
- typename IndexClass::distance_t *distances,
49
- idx_t *labels,
50
- const std::vector<typename IndexClass::distance_t>& all_distances,
51
- const std::vector<idx_t>& all_labels,
52
- const std::vector<long>& translations) {
53
- if (k == 0) {
54
- return;
55
- }
56
- using distance_t = typename IndexClass::distance_t;
57
-
58
- long stride = n * k;
45
+ void merge_tables(
46
+ long n,
47
+ long k,
48
+ long nshard,
49
+ typename IndexClass::distance_t* distances,
50
+ idx_t* labels,
51
+ const std::vector<typename IndexClass::distance_t>& all_distances,
52
+ const std::vector<idx_t>& all_labels,
53
+ const std::vector<long>& translations) {
54
+ if (k == 0) {
55
+ return;
56
+ }
57
+ using distance_t = typename IndexClass::distance_t;
58
+
59
+ long stride = n * k;
59
60
  #pragma omp parallel
60
- {
61
- std::vector<int> buf (2 * nshard);
62
- int * pointer = buf.data();
63
- int * shard_ids = pointer + nshard;
64
- std::vector<distance_t> buf2 (nshard);
65
- distance_t * heap_vals = buf2.data();
61
+ {
62
+ std::vector<int> buf(2 * nshard);
63
+ int* pointer = buf.data();
64
+ int* shard_ids = pointer + nshard;
65
+ std::vector<distance_t> buf2(nshard);
66
+ distance_t* heap_vals = buf2.data();
66
67
  #pragma omp for
67
- for (long i = 0; i < n; i++) {
68
- // the heap maps values to the shard where they are
69
- // produced.
70
- const distance_t *D_in = all_distances.data() + i * k;
71
- const idx_t *I_in = all_labels.data() + i * k;
72
- int heap_size = 0;
73
-
74
- for (long s = 0; s < nshard; s++) {
75
- pointer[s] = 0;
76
- if (I_in[stride * s] >= 0) {
77
- heap_push<C> (++heap_size, heap_vals, shard_ids,
78
- D_in[stride * s], s);
79
- }
80
- }
81
-
82
- distance_t *D = distances + i * k;
83
- idx_t *I = labels + i * k;
84
-
85
- for (int j = 0; j < k; j++) {
86
- if (heap_size == 0) {
87
- I[j] = -1;
88
- D[j] = C::neutral();
89
- } else {
90
- // pop best element
91
- int s = shard_ids[0];
92
- int & p = pointer[s];
93
- D[j] = heap_vals[0];
94
- I[j] = I_in[stride * s + p] + translations[s];
95
-
96
- heap_pop<C> (heap_size--, heap_vals, shard_ids);
97
- p++;
98
- if (p < k && I_in[stride * s + p] >= 0) {
99
- heap_push<C> (++heap_size, heap_vals, shard_ids,
100
- D_in[stride * s + p], s);
101
- }
68
+ for (long i = 0; i < n; i++) {
69
+ // the heap maps values to the shard where they are
70
+ // produced.
71
+ const distance_t* D_in = all_distances.data() + i * k;
72
+ const idx_t* I_in = all_labels.data() + i * k;
73
+ int heap_size = 0;
74
+
75
+ for (long s = 0; s < nshard; s++) {
76
+ pointer[s] = 0;
77
+ if (I_in[stride * s] >= 0) {
78
+ heap_push<C>(
79
+ ++heap_size,
80
+ heap_vals,
81
+ shard_ids,
82
+ D_in[stride * s],
83
+ s);
84
+ }
85
+ }
86
+
87
+ distance_t* D = distances + i * k;
88
+ idx_t* I = labels + i * k;
89
+
90
+ for (int j = 0; j < k; j++) {
91
+ if (heap_size == 0) {
92
+ I[j] = -1;
93
+ D[j] = C::neutral();
94
+ } else {
95
+ // pop best element
96
+ int s = shard_ids[0];
97
+ int& p = pointer[s];
98
+ D[j] = heap_vals[0];
99
+ I[j] = I_in[stride * s + p] + translations[s];
100
+
101
+ heap_pop<C>(heap_size--, heap_vals, shard_ids);
102
+ p++;
103
+ if (p < k && I_in[stride * s + p] >= 0) {
104
+ heap_push<C>(
105
+ ++heap_size,
106
+ heap_vals,
107
+ shard_ids,
108
+ D_in[stride * s + p],
109
+ s);
110
+ }
111
+ }
112
+ }
102
113
  }
103
- }
104
114
  }
105
- }
106
115
  }
107
116
 
108
117
  } // anonymous namespace
109
118
 
110
119
  template <typename IndexT>
111
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(idx_t d,
112
- bool threaded,
113
- bool successive_ids)
114
- : ThreadedIndex<IndexT>(d, threaded),
115
- successive_ids(successive_ids) {
116
- }
120
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
121
+ idx_t d,
122
+ bool threaded,
123
+ bool successive_ids)
124
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
117
125
 
118
126
  template <typename IndexT>
119
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(int d,
120
- bool threaded,
121
- bool successive_ids)
122
- : ThreadedIndex<IndexT>(d, threaded),
123
- successive_ids(successive_ids) {
124
- }
127
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
128
+ int d,
129
+ bool threaded,
130
+ bool successive_ids)
131
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
125
132
 
126
133
  template <typename IndexT>
127
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(bool threaded,
128
- bool successive_ids)
129
- : ThreadedIndex<IndexT>(threaded),
130
- successive_ids(successive_ids) {
131
- }
134
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
135
+ bool threaded,
136
+ bool successive_ids)
137
+ : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
132
138
 
133
139
  template <typename IndexT>
134
- void
135
- IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
136
- syncWithSubIndexes();
140
+ void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
141
+ syncWithSubIndexes();
137
142
  }
138
143
 
139
144
  template <typename IndexT>
140
- void
141
- IndexShardsTemplate<IndexT>::onAfterRemoveIndex(IndexT* index /* unused */) {
142
- syncWithSubIndexes();
145
+ void IndexShardsTemplate<IndexT>::onAfterRemoveIndex(
146
+ IndexT* index /* unused */) {
147
+ syncWithSubIndexes();
143
148
  }
144
149
 
145
150
  // FIXME: assumes that nothing is currently running on the sub-indexes, which is
146
151
  // true with the normal API, but should use the runOnIndex API instead
147
152
  template <typename IndexT>
148
- void
149
- IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
150
- if (!this->count()) {
151
- this->is_trained = false;
152
- this->ntotal = 0;
153
-
154
- return;
155
- }
156
-
157
- auto firstIndex = this->at(0);
158
- this->metric_type = firstIndex->metric_type;
159
- this->is_trained = firstIndex->is_trained;
160
- this->ntotal = firstIndex->ntotal;
161
-
162
- for (int i = 1; i < this->count(); ++i) {
163
- auto index = this->at(i);
164
- FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
165
- FAISS_THROW_IF_NOT(this->d == index->d);
166
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
167
-
168
- this->ntotal += index->ntotal;
169
- }
153
+ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
154
+ if (!this->count()) {
155
+ this->is_trained = false;
156
+ this->ntotal = 0;
157
+
158
+ return;
159
+ }
160
+
161
+ auto firstIndex = this->at(0);
162
+ this->metric_type = firstIndex->metric_type;
163
+ this->is_trained = firstIndex->is_trained;
164
+ this->ntotal = firstIndex->ntotal;
165
+
166
+ for (int i = 1; i < this->count(); ++i) {
167
+ auto index = this->at(i);
168
+ FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
169
+ FAISS_THROW_IF_NOT(this->d == index->d);
170
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
171
+
172
+ this->ntotal += index->ntotal;
173
+ }
170
174
  }
171
175
 
172
176
  // No metric_type for IndexBinary
173
177
  template <>
174
- void
175
- IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
176
- if (!this->count()) {
177
- this->is_trained = false;
178
- this->ntotal = 0;
179
-
180
- return;
181
- }
182
-
183
- auto firstIndex = this->at(0);
184
- this->is_trained = firstIndex->is_trained;
185
- this->ntotal = firstIndex->ntotal;
186
-
187
- for (int i = 1; i < this->count(); ++i) {
188
- auto index = this->at(i);
189
- FAISS_THROW_IF_NOT(this->d == index->d);
190
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
191
-
192
- this->ntotal += index->ntotal;
193
- }
178
+ void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
179
+ if (!this->count()) {
180
+ this->is_trained = false;
181
+ this->ntotal = 0;
182
+
183
+ return;
184
+ }
185
+
186
+ auto firstIndex = this->at(0);
187
+ this->is_trained = firstIndex->is_trained;
188
+ this->ntotal = firstIndex->ntotal;
189
+
190
+ for (int i = 1; i < this->count(); ++i) {
191
+ auto index = this->at(i);
192
+ FAISS_THROW_IF_NOT(this->d == index->d);
193
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
194
+
195
+ this->ntotal += index->ntotal;
196
+ }
194
197
  }
195
198
 
196
199
  template <typename IndexT>
197
- void
198
- IndexShardsTemplate<IndexT>::train(idx_t n,
199
- const component_t *x) {
200
- auto fn =
201
- [n, x](int no, IndexT *index) {
202
- if (index->verbose) {
203
- printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
- }
205
-
206
- index->train(n, x);
207
-
208
- if (index->verbose) {
209
- printf("end train shard %d\n", no);
210
- }
200
+ void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
201
+ auto fn = [n, x](int no, IndexT* index) {
202
+ if (index->verbose) {
203
+ printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
+ }
205
+
206
+ index->train(n, x);
207
+
208
+ if (index->verbose) {
209
+ printf("end train shard %d\n", no);
210
+ }
211
211
  };
212
212
 
213
- this->runOnIndex(fn);
214
- syncWithSubIndexes();
213
+ this->runOnIndex(fn);
214
+ syncWithSubIndexes();
215
215
  }
216
216
 
217
217
  template <typename IndexT>
218
- void
219
- IndexShardsTemplate<IndexT>::add(idx_t n,
220
- const component_t *x) {
221
- add_with_ids(n, x, nullptr);
218
+ void IndexShardsTemplate<IndexT>::add(idx_t n, const component_t* x) {
219
+ add_with_ids(n, x, nullptr);
222
220
  }
223
221
 
224
222
  template <typename IndexT>
225
- void
226
- IndexShardsTemplate<IndexT>::add_with_ids(idx_t n,
227
- const component_t * x,
228
- const idx_t *xids) {
229
-
230
- FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
231
- "It makes no sense to pass in ids and "
232
- "request them to be shifted");
233
-
234
- if (successive_ids) {
235
- FAISS_THROW_IF_NOT_MSG(!xids,
236
- "It makes no sense to pass in ids and "
237
- "request them to be shifted");
238
- FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
239
- "when adding to IndexShards with sucessive_ids, "
240
- "only add() in a single pass is supported");
241
- }
242
-
243
- idx_t nshard = this->count();
244
- const idx_t *ids = xids;
245
-
246
- std::vector<idx_t> aids;
247
-
248
- if (!ids && !successive_ids) {
249
- aids.resize(n);
250
-
251
- for (idx_t i = 0; i < n; i++) {
252
- aids[i] = this->ntotal + i;
223
+ void IndexShardsTemplate<IndexT>::add_with_ids(
224
+ idx_t n,
225
+ const component_t* x,
226
+ const idx_t* xids) {
227
+ FAISS_THROW_IF_NOT_MSG(
228
+ !(successive_ids && xids),
229
+ "It makes no sense to pass in ids and "
230
+ "request them to be shifted");
231
+
232
+ if (successive_ids) {
233
+ FAISS_THROW_IF_NOT_MSG(
234
+ !xids,
235
+ "It makes no sense to pass in ids and "
236
+ "request them to be shifted");
237
+ FAISS_THROW_IF_NOT_MSG(
238
+ this->ntotal == 0,
239
+ "when adding to IndexShards with sucessive_ids, "
240
+ "only add() in a single pass is supported");
253
241
  }
254
242
 
255
- ids = aids.data();
256
- }
243
+ idx_t nshard = this->count();
244
+ const idx_t* ids = xids;
245
+
246
+ std::vector<idx_t> aids;
247
+
248
+ if (!ids && !successive_ids) {
249
+ aids.resize(n);
250
+
251
+ for (idx_t i = 0; i < n; i++) {
252
+ aids[i] = this->ntotal + i;
253
+ }
254
+
255
+ ids = aids.data();
256
+ }
257
257
 
258
- size_t components_per_vec =
259
- sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
258
+ size_t components_per_vec =
259
+ sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
260
260
 
261
- auto fn =
262
- [n, ids, x, nshard, components_per_vec](int no, IndexT *index) {
263
- idx_t i0 = (idx_t) no * n / nshard;
264
- idx_t i1 = ((idx_t) no + 1) * n / nshard;
265
- auto x0 = x + i0 * components_per_vec;
261
+ auto fn = [n, ids, x, nshard, components_per_vec](int no, IndexT* index) {
262
+ idx_t i0 = (idx_t)no * n / nshard;
263
+ idx_t i1 = ((idx_t)no + 1) * n / nshard;
264
+ auto x0 = x + i0 * components_per_vec;
266
265
 
267
- if (index->verbose) {
268
- printf ("begin add shard %d on %" PRId64 " points\n", no, n);
269
- }
266
+ if (index->verbose) {
267
+ printf("begin add shard %d on %" PRId64 " points\n", no, n);
268
+ }
270
269
 
271
- if (ids) {
272
- index->add_with_ids (i1 - i0, x0, ids + i0);
273
- } else {
274
- index->add (i1 - i0, x0);
275
- }
270
+ if (ids) {
271
+ index->add_with_ids(i1 - i0, x0, ids + i0);
272
+ } else {
273
+ index->add(i1 - i0, x0);
274
+ }
276
275
 
277
- if (index->verbose) {
278
- printf ("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
279
- }
276
+ if (index->verbose) {
277
+ printf("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
278
+ }
280
279
  };
281
280
 
282
- this->runOnIndex(fn);
283
- syncWithSubIndexes();
281
+ this->runOnIndex(fn);
282
+ syncWithSubIndexes();
284
283
  }
285
284
 
286
285
  template <typename IndexT>
287
- void
288
- IndexShardsTemplate<IndexT>::search(idx_t n,
289
- const component_t *x,
290
- idx_t k,
291
- distance_t *distances,
292
- idx_t *labels) const {
293
- long nshard = this->count();
294
-
295
- std::vector<distance_t> all_distances(nshard * k * n);
296
- std::vector<idx_t> all_labels(nshard * k * n);
297
-
298
- auto fn =
299
- [n, k, x, &all_distances, &all_labels](int no, const IndexT *index) {
300
- if (index->verbose) {
301
- printf ("begin query shard %d on %" PRId64 " points\n", no, n);
302
- }
303
-
304
- index->search (n, x, k,
305
- all_distances.data() + no * k * n,
306
- all_labels.data() + no * k * n);
307
-
308
- if (index->verbose) {
309
- printf ("end query shard %d\n", no);
310
- }
286
+ void IndexShardsTemplate<IndexT>::search(
287
+ idx_t n,
288
+ const component_t* x,
289
+ idx_t k,
290
+ distance_t* distances,
291
+ idx_t* labels) const {
292
+ FAISS_THROW_IF_NOT(k > 0);
293
+
294
+ long nshard = this->count();
295
+
296
+ std::vector<distance_t> all_distances(nshard * k * n);
297
+ std::vector<idx_t> all_labels(nshard * k * n);
298
+
299
+ auto fn = [n, k, x, &all_distances, &all_labels](
300
+ int no, const IndexT* index) {
301
+ if (index->verbose) {
302
+ printf("begin query shard %d on %" PRId64 " points\n", no, n);
303
+ }
304
+
305
+ index->search(
306
+ n,
307
+ x,
308
+ k,
309
+ all_distances.data() + no * k * n,
310
+ all_labels.data() + no * k * n);
311
+
312
+ if (index->verbose) {
313
+ printf("end query shard %d\n", no);
314
+ }
311
315
  };
312
316
 
313
- this->runOnIndex(fn);
317
+ this->runOnIndex(fn);
318
+
319
+ std::vector<long> translations(nshard, 0);
314
320
 
315
- std::vector<long> translations(nshard, 0);
321
+ // Because we just called runOnIndex above, it is safe to access the
322
+ // sub-index ntotal here
323
+ if (successive_ids) {
324
+ translations[0] = 0;
316
325
 
317
- // Because we just called runOnIndex above, it is safe to access the sub-index
318
- // ntotal here
319
- if (successive_ids) {
320
- translations[0] = 0;
326
+ for (int s = 0; s + 1 < nshard; s++) {
327
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
328
+ }
329
+ }
321
330
 
322
- for (int s = 0; s + 1 < nshard; s++) {
323
- translations[s + 1] = translations[s] + this->at(s)->ntotal;
331
+ if (this->metric_type == METRIC_L2) {
332
+ merge_tables<IndexT, CMin<distance_t, int>>(
333
+ n,
334
+ k,
335
+ nshard,
336
+ distances,
337
+ labels,
338
+ all_distances,
339
+ all_labels,
340
+ translations);
341
+ } else {
342
+ merge_tables<IndexT, CMax<distance_t, int>>(
343
+ n,
344
+ k,
345
+ nshard,
346
+ distances,
347
+ labels,
348
+ all_distances,
349
+ all_labels,
350
+ translations);
324
351
  }
325
- }
326
-
327
- if (this->metric_type == METRIC_L2) {
328
- merge_tables<IndexT, CMin<distance_t, int>>(
329
- n, k, nshard, distances, labels,
330
- all_distances, all_labels, translations);
331
- } else {
332
- merge_tables<IndexT, CMax<distance_t, int>>(
333
- n, k, nshard, distances, labels,
334
- all_distances, all_labels, translations);
335
- }
336
352
  }
337
353
 
338
354
  // explicit instanciations