faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -16,7 +16,6 @@
16
16
  #include <faiss/IndexIVF.h>
17
17
  #include <faiss/impl/ScalarQuantizer.h>
18
18
 
19
-
20
19
  namespace faiss {
21
20
 
22
21
  /**
@@ -25,10 +24,7 @@ namespace faiss {
25
24
  * (default).
26
25
  */
27
26
 
28
-
29
-
30
-
31
- struct IndexScalarQuantizer: Index {
27
+ struct IndexScalarQuantizer : Index {
32
28
  /// Used to encode the vectors
33
29
  ScalarQuantizer sq;
34
30
 
@@ -43,22 +39,23 @@ struct IndexScalarQuantizer: Index {
43
39
  * @param M number of subquantizers
44
40
  * @param nbits number of bit per subvector index
45
41
  */
46
- IndexScalarQuantizer (int d,
47
- ScalarQuantizer::QuantizerType qtype,
48
- MetricType metric = METRIC_L2);
42
+ IndexScalarQuantizer(
43
+ int d,
44
+ ScalarQuantizer::QuantizerType qtype,
45
+ MetricType metric = METRIC_L2);
49
46
 
50
- IndexScalarQuantizer ();
47
+ IndexScalarQuantizer();
51
48
 
52
49
  void train(idx_t n, const float* x) override;
53
50
 
54
51
  void add(idx_t n, const float* x) override;
55
52
 
56
53
  void search(
57
- idx_t n,
58
- const float* x,
59
- idx_t k,
60
- float* distances,
61
- idx_t* labels) const override;
54
+ idx_t n,
55
+ const float* x,
56
+ idx_t k,
57
+ float* distances,
58
+ idx_t* labels) const override;
62
59
 
63
60
  void reset() override;
64
61
 
@@ -66,62 +63,61 @@ struct IndexScalarQuantizer: Index {
66
63
 
67
64
  void reconstruct(idx_t key, float* recons) const override;
68
65
 
69
- DistanceComputer *get_distance_computer () const override;
66
+ DistanceComputer* get_distance_computer() const override;
70
67
 
71
68
  /* standalone codec interface */
72
- size_t sa_code_size () const override;
73
-
74
- void sa_encode (idx_t n, const float *x,
75
- uint8_t *bytes) const override;
76
-
77
- void sa_decode (idx_t n, const uint8_t *bytes,
78
- float *x) const override;
69
+ size_t sa_code_size() const override;
79
70
 
71
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
80
72
 
73
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
81
74
  };
82
75
 
83
-
84
- /** An IVF implementation where the components of the residuals are
76
+ /** An IVF implementation where the components of the residuals are
85
77
  * encoded with a scalar quantizer. All distance computations
86
78
  * are asymmetric, so the encoded vectors are decoded and approximate
87
79
  * distances are computed.
88
80
  */
89
81
 
90
- struct IndexIVFScalarQuantizer: IndexIVF {
82
+ struct IndexIVFScalarQuantizer : IndexIVF {
91
83
  ScalarQuantizer sq;
92
84
  bool by_residual;
93
85
 
94
- IndexIVFScalarQuantizer(Index *quantizer, size_t d, size_t nlist,
95
- ScalarQuantizer::QuantizerType qtype,
96
- MetricType metric = METRIC_L2,
97
- bool encode_residual = true);
86
+ IndexIVFScalarQuantizer(
87
+ Index* quantizer,
88
+ size_t d,
89
+ size_t nlist,
90
+ ScalarQuantizer::QuantizerType qtype,
91
+ MetricType metric = METRIC_L2,
92
+ bool encode_residual = true);
98
93
 
99
94
  IndexIVFScalarQuantizer();
100
95
 
101
96
  void train_residual(idx_t n, const float* x) override;
102
97
 
103
- void encode_vectors(idx_t n, const float* x,
104
- const idx_t *list_nos,
105
- uint8_t * codes,
106
- bool include_listnos=false) const override;
107
-
108
- void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
98
+ void encode_vectors(
99
+ idx_t n,
100
+ const float* x,
101
+ const idx_t* list_nos,
102
+ uint8_t* codes,
103
+ bool include_listnos = false) const override;
109
104
 
110
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
111
- const override;
105
+ void add_core(
106
+ idx_t n,
107
+ const float* x,
108
+ const idx_t* xids,
109
+ const idx_t* precomputed_idx) override;
112
110
 
111
+ InvertedListScanner* get_InvertedListScanner(
112
+ bool store_pairs) const override;
113
113
 
114
- void reconstruct_from_offset (int64_t list_no, int64_t offset,
115
- float* recons) const override;
114
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
115
+ const override;
116
116
 
117
117
  /* standalone codec interface */
118
- void sa_decode (idx_t n, const uint8_t *bytes,
119
- float *x) const override;
120
-
118
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
121
119
  };
122
120
 
123
-
124
- }
125
-
121
+ } // namespace faiss
126
122
 
127
123
  #endif
@@ -24,18 +24,17 @@ namespace {
24
24
 
25
25
  typedef Index::idx_t idx_t;
26
26
 
27
-
28
27
  // add translation to all valid labels
29
- void translate_labels (long n, idx_t *labels, long translation)
30
- {
31
- if (translation == 0) return;
28
+ void translate_labels(long n, idx_t* labels, long translation) {
29
+ if (translation == 0)
30
+ return;
32
31
  for (long i = 0; i < n; i++) {
33
- if(labels[i] < 0) continue;
32
+ if (labels[i] < 0)
33
+ continue;
34
34
  labels[i] += translation;
35
35
  }
36
36
  }
37
37
 
38
-
39
38
  /** merge result tables from several shards.
40
39
  * @param all_distances size nshard * n * k
41
40
  * @param all_labels idem
@@ -43,296 +42,313 @@ void translate_labels (long n, idx_t *labels, long translation)
43
42
  */
44
43
 
45
44
  template <class IndexClass, class C>
46
- void
47
- merge_tables(long n, long k, long nshard,
48
- typename IndexClass::distance_t *distances,
49
- idx_t *labels,
50
- const std::vector<typename IndexClass::distance_t>& all_distances,
51
- const std::vector<idx_t>& all_labels,
52
- const std::vector<long>& translations) {
53
- if (k == 0) {
54
- return;
55
- }
56
- using distance_t = typename IndexClass::distance_t;
57
-
58
- long stride = n * k;
45
+ void merge_tables(
46
+ long n,
47
+ long k,
48
+ long nshard,
49
+ typename IndexClass::distance_t* distances,
50
+ idx_t* labels,
51
+ const std::vector<typename IndexClass::distance_t>& all_distances,
52
+ const std::vector<idx_t>& all_labels,
53
+ const std::vector<long>& translations) {
54
+ if (k == 0) {
55
+ return;
56
+ }
57
+ using distance_t = typename IndexClass::distance_t;
58
+
59
+ long stride = n * k;
59
60
  #pragma omp parallel
60
- {
61
- std::vector<int> buf (2 * nshard);
62
- int * pointer = buf.data();
63
- int * shard_ids = pointer + nshard;
64
- std::vector<distance_t> buf2 (nshard);
65
- distance_t * heap_vals = buf2.data();
61
+ {
62
+ std::vector<int> buf(2 * nshard);
63
+ int* pointer = buf.data();
64
+ int* shard_ids = pointer + nshard;
65
+ std::vector<distance_t> buf2(nshard);
66
+ distance_t* heap_vals = buf2.data();
66
67
  #pragma omp for
67
- for (long i = 0; i < n; i++) {
68
- // the heap maps values to the shard where they are
69
- // produced.
70
- const distance_t *D_in = all_distances.data() + i * k;
71
- const idx_t *I_in = all_labels.data() + i * k;
72
- int heap_size = 0;
73
-
74
- for (long s = 0; s < nshard; s++) {
75
- pointer[s] = 0;
76
- if (I_in[stride * s] >= 0) {
77
- heap_push<C> (++heap_size, heap_vals, shard_ids,
78
- D_in[stride * s], s);
79
- }
80
- }
81
-
82
- distance_t *D = distances + i * k;
83
- idx_t *I = labels + i * k;
84
-
85
- for (int j = 0; j < k; j++) {
86
- if (heap_size == 0) {
87
- I[j] = -1;
88
- D[j] = C::neutral();
89
- } else {
90
- // pop best element
91
- int s = shard_ids[0];
92
- int & p = pointer[s];
93
- D[j] = heap_vals[0];
94
- I[j] = I_in[stride * s + p] + translations[s];
95
-
96
- heap_pop<C> (heap_size--, heap_vals, shard_ids);
97
- p++;
98
- if (p < k && I_in[stride * s + p] >= 0) {
99
- heap_push<C> (++heap_size, heap_vals, shard_ids,
100
- D_in[stride * s + p], s);
101
- }
68
+ for (long i = 0; i < n; i++) {
69
+ // the heap maps values to the shard where they are
70
+ // produced.
71
+ const distance_t* D_in = all_distances.data() + i * k;
72
+ const idx_t* I_in = all_labels.data() + i * k;
73
+ int heap_size = 0;
74
+
75
+ for (long s = 0; s < nshard; s++) {
76
+ pointer[s] = 0;
77
+ if (I_in[stride * s] >= 0) {
78
+ heap_push<C>(
79
+ ++heap_size,
80
+ heap_vals,
81
+ shard_ids,
82
+ D_in[stride * s],
83
+ s);
84
+ }
85
+ }
86
+
87
+ distance_t* D = distances + i * k;
88
+ idx_t* I = labels + i * k;
89
+
90
+ for (int j = 0; j < k; j++) {
91
+ if (heap_size == 0) {
92
+ I[j] = -1;
93
+ D[j] = C::neutral();
94
+ } else {
95
+ // pop best element
96
+ int s = shard_ids[0];
97
+ int& p = pointer[s];
98
+ D[j] = heap_vals[0];
99
+ I[j] = I_in[stride * s + p] + translations[s];
100
+
101
+ heap_pop<C>(heap_size--, heap_vals, shard_ids);
102
+ p++;
103
+ if (p < k && I_in[stride * s + p] >= 0) {
104
+ heap_push<C>(
105
+ ++heap_size,
106
+ heap_vals,
107
+ shard_ids,
108
+ D_in[stride * s + p],
109
+ s);
110
+ }
111
+ }
112
+ }
102
113
  }
103
- }
104
114
  }
105
- }
106
115
  }
107
116
 
108
117
  } // anonymous namespace
109
118
 
110
119
  template <typename IndexT>
111
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(idx_t d,
112
- bool threaded,
113
- bool successive_ids)
114
- : ThreadedIndex<IndexT>(d, threaded),
115
- successive_ids(successive_ids) {
116
- }
120
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
121
+ idx_t d,
122
+ bool threaded,
123
+ bool successive_ids)
124
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
117
125
 
118
126
  template <typename IndexT>
119
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(int d,
120
- bool threaded,
121
- bool successive_ids)
122
- : ThreadedIndex<IndexT>(d, threaded),
123
- successive_ids(successive_ids) {
124
- }
127
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
128
+ int d,
129
+ bool threaded,
130
+ bool successive_ids)
131
+ : ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
125
132
 
126
133
  template <typename IndexT>
127
- IndexShardsTemplate<IndexT>::IndexShardsTemplate(bool threaded,
128
- bool successive_ids)
129
- : ThreadedIndex<IndexT>(threaded),
130
- successive_ids(successive_ids) {
131
- }
134
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
135
+ bool threaded,
136
+ bool successive_ids)
137
+ : ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
132
138
 
133
139
  template <typename IndexT>
134
- void
135
- IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
136
- syncWithSubIndexes();
140
+ void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
141
+ syncWithSubIndexes();
137
142
  }
138
143
 
139
144
  template <typename IndexT>
140
- void
141
- IndexShardsTemplate<IndexT>::onAfterRemoveIndex(IndexT* index /* unused */) {
142
- syncWithSubIndexes();
145
+ void IndexShardsTemplate<IndexT>::onAfterRemoveIndex(
146
+ IndexT* index /* unused */) {
147
+ syncWithSubIndexes();
143
148
  }
144
149
 
145
150
  // FIXME: assumes that nothing is currently running on the sub-indexes, which is
146
151
  // true with the normal API, but should use the runOnIndex API instead
147
152
  template <typename IndexT>
148
- void
149
- IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
150
- if (!this->count()) {
151
- this->is_trained = false;
152
- this->ntotal = 0;
153
-
154
- return;
155
- }
156
-
157
- auto firstIndex = this->at(0);
158
- this->metric_type = firstIndex->metric_type;
159
- this->is_trained = firstIndex->is_trained;
160
- this->ntotal = firstIndex->ntotal;
161
-
162
- for (int i = 1; i < this->count(); ++i) {
163
- auto index = this->at(i);
164
- FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
165
- FAISS_THROW_IF_NOT(this->d == index->d);
166
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
167
-
168
- this->ntotal += index->ntotal;
169
- }
153
+ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
154
+ if (!this->count()) {
155
+ this->is_trained = false;
156
+ this->ntotal = 0;
157
+
158
+ return;
159
+ }
160
+
161
+ auto firstIndex = this->at(0);
162
+ this->metric_type = firstIndex->metric_type;
163
+ this->is_trained = firstIndex->is_trained;
164
+ this->ntotal = firstIndex->ntotal;
165
+
166
+ for (int i = 1; i < this->count(); ++i) {
167
+ auto index = this->at(i);
168
+ FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
169
+ FAISS_THROW_IF_NOT(this->d == index->d);
170
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
171
+
172
+ this->ntotal += index->ntotal;
173
+ }
170
174
  }
171
175
 
172
176
  // No metric_type for IndexBinary
173
177
  template <>
174
- void
175
- IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
176
- if (!this->count()) {
177
- this->is_trained = false;
178
- this->ntotal = 0;
179
-
180
- return;
181
- }
182
-
183
- auto firstIndex = this->at(0);
184
- this->is_trained = firstIndex->is_trained;
185
- this->ntotal = firstIndex->ntotal;
186
-
187
- for (int i = 1; i < this->count(); ++i) {
188
- auto index = this->at(i);
189
- FAISS_THROW_IF_NOT(this->d == index->d);
190
- FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
191
-
192
- this->ntotal += index->ntotal;
193
- }
178
+ void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
179
+ if (!this->count()) {
180
+ this->is_trained = false;
181
+ this->ntotal = 0;
182
+
183
+ return;
184
+ }
185
+
186
+ auto firstIndex = this->at(0);
187
+ this->is_trained = firstIndex->is_trained;
188
+ this->ntotal = firstIndex->ntotal;
189
+
190
+ for (int i = 1; i < this->count(); ++i) {
191
+ auto index = this->at(i);
192
+ FAISS_THROW_IF_NOT(this->d == index->d);
193
+ FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
194
+
195
+ this->ntotal += index->ntotal;
196
+ }
194
197
  }
195
198
 
196
199
  template <typename IndexT>
197
- void
198
- IndexShardsTemplate<IndexT>::train(idx_t n,
199
- const component_t *x) {
200
- auto fn =
201
- [n, x](int no, IndexT *index) {
202
- if (index->verbose) {
203
- printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
- }
205
-
206
- index->train(n, x);
207
-
208
- if (index->verbose) {
209
- printf("end train shard %d\n", no);
210
- }
200
+ void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
201
+ auto fn = [n, x](int no, IndexT* index) {
202
+ if (index->verbose) {
203
+ printf("begin train shard %d on %" PRId64 " points\n", no, n);
204
+ }
205
+
206
+ index->train(n, x);
207
+
208
+ if (index->verbose) {
209
+ printf("end train shard %d\n", no);
210
+ }
211
211
  };
212
212
 
213
- this->runOnIndex(fn);
214
- syncWithSubIndexes();
213
+ this->runOnIndex(fn);
214
+ syncWithSubIndexes();
215
215
  }
216
216
 
217
217
  template <typename IndexT>
218
- void
219
- IndexShardsTemplate<IndexT>::add(idx_t n,
220
- const component_t *x) {
221
- add_with_ids(n, x, nullptr);
218
+ void IndexShardsTemplate<IndexT>::add(idx_t n, const component_t* x) {
219
+ add_with_ids(n, x, nullptr);
222
220
  }
223
221
 
224
222
  template <typename IndexT>
225
- void
226
- IndexShardsTemplate<IndexT>::add_with_ids(idx_t n,
227
- const component_t * x,
228
- const idx_t *xids) {
229
-
230
- FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
231
- "It makes no sense to pass in ids and "
232
- "request them to be shifted");
233
-
234
- if (successive_ids) {
235
- FAISS_THROW_IF_NOT_MSG(!xids,
236
- "It makes no sense to pass in ids and "
237
- "request them to be shifted");
238
- FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
239
- "when adding to IndexShards with sucessive_ids, "
240
- "only add() in a single pass is supported");
241
- }
242
-
243
- idx_t nshard = this->count();
244
- const idx_t *ids = xids;
245
-
246
- std::vector<idx_t> aids;
247
-
248
- if (!ids && !successive_ids) {
249
- aids.resize(n);
250
-
251
- for (idx_t i = 0; i < n; i++) {
252
- aids[i] = this->ntotal + i;
223
+ void IndexShardsTemplate<IndexT>::add_with_ids(
224
+ idx_t n,
225
+ const component_t* x,
226
+ const idx_t* xids) {
227
+ FAISS_THROW_IF_NOT_MSG(
228
+ !(successive_ids && xids),
229
+ "It makes no sense to pass in ids and "
230
+ "request them to be shifted");
231
+
232
+ if (successive_ids) {
233
+ FAISS_THROW_IF_NOT_MSG(
234
+ !xids,
235
+ "It makes no sense to pass in ids and "
236
+ "request them to be shifted");
237
+ FAISS_THROW_IF_NOT_MSG(
238
+ this->ntotal == 0,
239
+ "when adding to IndexShards with sucessive_ids, "
240
+ "only add() in a single pass is supported");
253
241
  }
254
242
 
255
- ids = aids.data();
256
- }
243
+ idx_t nshard = this->count();
244
+ const idx_t* ids = xids;
245
+
246
+ std::vector<idx_t> aids;
247
+
248
+ if (!ids && !successive_ids) {
249
+ aids.resize(n);
250
+
251
+ for (idx_t i = 0; i < n; i++) {
252
+ aids[i] = this->ntotal + i;
253
+ }
254
+
255
+ ids = aids.data();
256
+ }
257
257
 
258
- size_t components_per_vec =
259
- sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
258
+ size_t components_per_vec =
259
+ sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
260
260
 
261
- auto fn =
262
- [n, ids, x, nshard, components_per_vec](int no, IndexT *index) {
263
- idx_t i0 = (idx_t) no * n / nshard;
264
- idx_t i1 = ((idx_t) no + 1) * n / nshard;
265
- auto x0 = x + i0 * components_per_vec;
261
+ auto fn = [n, ids, x, nshard, components_per_vec](int no, IndexT* index) {
262
+ idx_t i0 = (idx_t)no * n / nshard;
263
+ idx_t i1 = ((idx_t)no + 1) * n / nshard;
264
+ auto x0 = x + i0 * components_per_vec;
266
265
 
267
- if (index->verbose) {
268
- printf ("begin add shard %d on %" PRId64 " points\n", no, n);
269
- }
266
+ if (index->verbose) {
267
+ printf("begin add shard %d on %" PRId64 " points\n", no, n);
268
+ }
270
269
 
271
- if (ids) {
272
- index->add_with_ids (i1 - i0, x0, ids + i0);
273
- } else {
274
- index->add (i1 - i0, x0);
275
- }
270
+ if (ids) {
271
+ index->add_with_ids(i1 - i0, x0, ids + i0);
272
+ } else {
273
+ index->add(i1 - i0, x0);
274
+ }
276
275
 
277
- if (index->verbose) {
278
- printf ("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
279
- }
276
+ if (index->verbose) {
277
+ printf("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
278
+ }
280
279
  };
281
280
 
282
- this->runOnIndex(fn);
283
- syncWithSubIndexes();
281
+ this->runOnIndex(fn);
282
+ syncWithSubIndexes();
284
283
  }
285
284
 
286
285
  template <typename IndexT>
287
- void
288
- IndexShardsTemplate<IndexT>::search(idx_t n,
289
- const component_t *x,
290
- idx_t k,
291
- distance_t *distances,
292
- idx_t *labels) const {
293
- long nshard = this->count();
294
-
295
- std::vector<distance_t> all_distances(nshard * k * n);
296
- std::vector<idx_t> all_labels(nshard * k * n);
297
-
298
- auto fn =
299
- [n, k, x, &all_distances, &all_labels](int no, const IndexT *index) {
300
- if (index->verbose) {
301
- printf ("begin query shard %d on %" PRId64 " points\n", no, n);
302
- }
303
-
304
- index->search (n, x, k,
305
- all_distances.data() + no * k * n,
306
- all_labels.data() + no * k * n);
307
-
308
- if (index->verbose) {
309
- printf ("end query shard %d\n", no);
310
- }
286
+ void IndexShardsTemplate<IndexT>::search(
287
+ idx_t n,
288
+ const component_t* x,
289
+ idx_t k,
290
+ distance_t* distances,
291
+ idx_t* labels) const {
292
+ FAISS_THROW_IF_NOT(k > 0);
293
+
294
+ long nshard = this->count();
295
+
296
+ std::vector<distance_t> all_distances(nshard * k * n);
297
+ std::vector<idx_t> all_labels(nshard * k * n);
298
+
299
+ auto fn = [n, k, x, &all_distances, &all_labels](
300
+ int no, const IndexT* index) {
301
+ if (index->verbose) {
302
+ printf("begin query shard %d on %" PRId64 " points\n", no, n);
303
+ }
304
+
305
+ index->search(
306
+ n,
307
+ x,
308
+ k,
309
+ all_distances.data() + no * k * n,
310
+ all_labels.data() + no * k * n);
311
+
312
+ if (index->verbose) {
313
+ printf("end query shard %d\n", no);
314
+ }
311
315
  };
312
316
 
313
- this->runOnIndex(fn);
317
+ this->runOnIndex(fn);
318
+
319
+ std::vector<long> translations(nshard, 0);
314
320
 
315
- std::vector<long> translations(nshard, 0);
321
+ // Because we just called runOnIndex above, it is safe to access the
322
+ // sub-index ntotal here
323
+ if (successive_ids) {
324
+ translations[0] = 0;
316
325
 
317
- // Because we just called runOnIndex above, it is safe to access the sub-index
318
- // ntotal here
319
- if (successive_ids) {
320
- translations[0] = 0;
326
+ for (int s = 0; s + 1 < nshard; s++) {
327
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
328
+ }
329
+ }
321
330
 
322
- for (int s = 0; s + 1 < nshard; s++) {
323
- translations[s + 1] = translations[s] + this->at(s)->ntotal;
331
+ if (this->metric_type == METRIC_L2) {
332
+ merge_tables<IndexT, CMin<distance_t, int>>(
333
+ n,
334
+ k,
335
+ nshard,
336
+ distances,
337
+ labels,
338
+ all_distances,
339
+ all_labels,
340
+ translations);
341
+ } else {
342
+ merge_tables<IndexT, CMax<distance_t, int>>(
343
+ n,
344
+ k,
345
+ nshard,
346
+ distances,
347
+ labels,
348
+ all_distances,
349
+ all_labels,
350
+ translations);
324
351
  }
325
- }
326
-
327
- if (this->metric_type == METRIC_L2) {
328
- merge_tables<IndexT, CMin<distance_t, int>>(
329
- n, k, nshard, distances, labels,
330
- all_distances, all_labels, translations);
331
- } else {
332
- merge_tables<IndexT, CMax<distance_t, int>>(
333
- n, k, nshard, distances, labels,
334
- all_distances, all_labels, translations);
335
- }
336
352
  }
337
353
 
338
354
  // explicit instanciations