faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,127 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_INDEX_SCALAR_QUANTIZER_H
11
+ #define FAISS_INDEX_SCALAR_QUANTIZER_H
12
+
13
+ #include <stdint.h>
14
+ #include <vector>
15
+
16
+ #include <faiss/IndexIVF.h>
17
+ #include <faiss/impl/ScalarQuantizer.h>
18
+
19
+
20
+ namespace faiss {
21
+
22
+ /**
23
+ * The uniform quantizer has a range [vmin, vmax]. The range can be
24
+ * the same for all dimensions (uniform) or specific per dimension
25
+ * (default).
26
+ */
27
+
28
+
29
+
30
+
31
+ struct IndexScalarQuantizer: Index {
32
+ /// Used to encode the vectors
33
+ ScalarQuantizer sq;
34
+
35
+ /// Codes. Size ntotal * pq.code_size
36
+ std::vector<uint8_t> codes;
37
+
38
+ size_t code_size;
39
+
40
+ /** Constructor.
41
+ *
42
+ * @param d dimensionality of the input vectors
43
+ * @param M number of subquantizers
44
+ * @param nbits number of bit per subvector index
45
+ */
46
+ IndexScalarQuantizer (int d,
47
+ ScalarQuantizer::QuantizerType qtype,
48
+ MetricType metric = METRIC_L2);
49
+
50
+ IndexScalarQuantizer ();
51
+
52
+ void train(idx_t n, const float* x) override;
53
+
54
+ void add(idx_t n, const float* x) override;
55
+
56
+ void search(
57
+ idx_t n,
58
+ const float* x,
59
+ idx_t k,
60
+ float* distances,
61
+ idx_t* labels) const override;
62
+
63
+ void reset() override;
64
+
65
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
66
+
67
+ void reconstruct(idx_t key, float* recons) const override;
68
+
69
+ DistanceComputer *get_distance_computer () const override;
70
+
71
+ /* standalone codec interface */
72
+ size_t sa_code_size () const override;
73
+
74
+ void sa_encode (idx_t n, const float *x,
75
+ uint8_t *bytes) const override;
76
+
77
+ void sa_decode (idx_t n, const uint8_t *bytes,
78
+ float *x) const override;
79
+
80
+
81
+ };
82
+
83
+
84
+ /** An IVF implementation where the components of the residuals are
85
+ * encoded with a scalar uniform quantizer. All distance computations
86
+ * are asymmetric, so the encoded vectors are decoded and approximate
87
+ * distances are computed.
88
+ */
89
+
90
+ struct IndexIVFScalarQuantizer: IndexIVF {
91
+ ScalarQuantizer sq;
92
+ bool by_residual;
93
+
94
+ IndexIVFScalarQuantizer(Index *quantizer, size_t d, size_t nlist,
95
+ ScalarQuantizer::QuantizerType qtype,
96
+ MetricType metric = METRIC_L2,
97
+ bool encode_residual = true);
98
+
99
+ IndexIVFScalarQuantizer();
100
+
101
+ void train_residual(idx_t n, const float* x) override;
102
+
103
+ void encode_vectors(idx_t n, const float* x,
104
+ const idx_t *list_nos,
105
+ uint8_t * codes,
106
+ bool include_listnos=false) const override;
107
+
108
+ void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
109
+
110
+ InvertedListScanner *get_InvertedListScanner (bool store_pairs)
111
+ const override;
112
+
113
+
114
+ void reconstruct_from_offset (int64_t list_no, int64_t offset,
115
+ float* recons) const override;
116
+
117
+ /* standalone codec interface */
118
+ void sa_decode (idx_t n, const uint8_t *bytes,
119
+ float *x) const override;
120
+
121
+ };
122
+
123
+
124
+ }
125
+
126
+
127
+ #endif
@@ -0,0 +1,317 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexShards.h>
11
+
12
+ #include <cstdio>
13
+ #include <functional>
14
+
15
+ #include <faiss/impl/FaissAssert.h>
16
+ #include <faiss/utils/Heap.h>
17
+ #include <faiss/utils/WorkerThread.h>
18
+
19
+ namespace faiss {
20
+
21
+ // subroutines
22
+ namespace {
23
+
24
+ typedef Index::idx_t idx_t;
25
+
26
+
27
+ // add translation to all valid labels
28
+ void translate_labels (long n, idx_t *labels, long translation)
29
+ {
30
+ if (translation == 0) return;
31
+ for (long i = 0; i < n; i++) {
32
+ if(labels[i] < 0) continue;
33
+ labels[i] += translation;
34
+ }
35
+ }
36
+
37
+
38
+ /** merge result tables from several shards.
39
+ * @param all_distances size nshard * n * k
40
+ * @param all_labels idem
41
+ * @param translartions label translations to apply, size nshard
42
+ */
43
+
44
+ template <class IndexClass, class C>
45
+ void
46
+ merge_tables(long n, long k, long nshard,
47
+ typename IndexClass::distance_t *distances,
48
+ idx_t *labels,
49
+ const std::vector<typename IndexClass::distance_t>& all_distances,
50
+ const std::vector<idx_t>& all_labels,
51
+ const std::vector<long>& translations) {
52
+ if (k == 0) {
53
+ return;
54
+ }
55
+ using distance_t = typename IndexClass::distance_t;
56
+
57
+ long stride = n * k;
58
+ #pragma omp parallel
59
+ {
60
+ std::vector<int> buf (2 * nshard);
61
+ int * pointer = buf.data();
62
+ int * shard_ids = pointer + nshard;
63
+ std::vector<distance_t> buf2 (nshard);
64
+ distance_t * heap_vals = buf2.data();
65
+ #pragma omp for
66
+ for (long i = 0; i < n; i++) {
67
+ // the heap maps values to the shard where they are
68
+ // produced.
69
+ const distance_t *D_in = all_distances.data() + i * k;
70
+ const idx_t *I_in = all_labels.data() + i * k;
71
+ int heap_size = 0;
72
+
73
+ for (long s = 0; s < nshard; s++) {
74
+ pointer[s] = 0;
75
+ if (I_in[stride * s] >= 0) {
76
+ heap_push<C> (++heap_size, heap_vals, shard_ids,
77
+ D_in[stride * s], s);
78
+ }
79
+ }
80
+
81
+ distance_t *D = distances + i * k;
82
+ idx_t *I = labels + i * k;
83
+
84
+ for (int j = 0; j < k; j++) {
85
+ if (heap_size == 0) {
86
+ I[j] = -1;
87
+ D[j] = C::neutral();
88
+ } else {
89
+ // pop best element
90
+ int s = shard_ids[0];
91
+ int & p = pointer[s];
92
+ D[j] = heap_vals[0];
93
+ I[j] = I_in[stride * s + p] + translations[s];
94
+
95
+ heap_pop<C> (heap_size--, heap_vals, shard_ids);
96
+ p++;
97
+ if (p < k && I_in[stride * s + p] >= 0) {
98
+ heap_push<C> (++heap_size, heap_vals, shard_ids,
99
+ D_in[stride * s + p], s);
100
+ }
101
+ }
102
+ }
103
+ }
104
+ }
105
+ }
106
+
107
+ } // anonymous namespace
108
+
109
+ template <typename IndexT>
110
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(idx_t d,
111
+ bool threaded,
112
+ bool successive_ids)
113
+ : ThreadedIndex<IndexT>(d, threaded),
114
+ successive_ids(successive_ids) {
115
+ }
116
+
117
+ template <typename IndexT>
118
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(int d,
119
+ bool threaded,
120
+ bool successive_ids)
121
+ : ThreadedIndex<IndexT>(d, threaded),
122
+ successive_ids(successive_ids) {
123
+ }
124
+
125
+ template <typename IndexT>
126
+ IndexShardsTemplate<IndexT>::IndexShardsTemplate(bool threaded,
127
+ bool successive_ids)
128
+ : ThreadedIndex<IndexT>(threaded),
129
+ successive_ids(successive_ids) {
130
+ }
131
+
132
+ template <typename IndexT>
133
+ void
134
+ IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
135
+ sync_with_shard_indexes();
136
+ }
137
+
138
+ template <typename IndexT>
139
+ void
140
+ IndexShardsTemplate<IndexT>::onAfterRemoveIndex(IndexT* index /* unused */) {
141
+ sync_with_shard_indexes();
142
+ }
143
+
144
+ template <typename IndexT>
145
+ void
146
+ IndexShardsTemplate<IndexT>::sync_with_shard_indexes() {
147
+ if (!this->count()) {
148
+ this->is_trained = false;
149
+ this->ntotal = 0;
150
+
151
+ return;
152
+ }
153
+
154
+ auto firstIndex = this->at(0);
155
+ this->metric_type = firstIndex->metric_type;
156
+ this->is_trained = firstIndex->is_trained;
157
+ this->ntotal = firstIndex->ntotal;
158
+
159
+ for (int i = 1; i < this->count(); ++i) {
160
+ auto index = this->at(i);
161
+ FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
162
+ FAISS_THROW_IF_NOT(this->d == index->d);
163
+
164
+ this->ntotal += index->ntotal;
165
+ }
166
+ }
167
+
168
+ template <typename IndexT>
169
+ void
170
+ IndexShardsTemplate<IndexT>::train(idx_t n,
171
+ const component_t *x) {
172
+ auto fn =
173
+ [n, x](int no, IndexT *index) {
174
+ if (index->verbose) {
175
+ printf("begin train shard %d on %ld points\n", no, n);
176
+ }
177
+
178
+ index->train(n, x);
179
+
180
+ if (index->verbose) {
181
+ printf("end train shard %d\n", no);
182
+ }
183
+ };
184
+
185
+ this->runOnIndex(fn);
186
+ sync_with_shard_indexes();
187
+ }
188
+
189
+ template <typename IndexT>
190
+ void
191
+ IndexShardsTemplate<IndexT>::add(idx_t n,
192
+ const component_t *x) {
193
+ add_with_ids(n, x, nullptr);
194
+ }
195
+
196
+ template <typename IndexT>
197
+ void
198
+ IndexShardsTemplate<IndexT>::add_with_ids(idx_t n,
199
+ const component_t * x,
200
+ const idx_t *xids) {
201
+
202
+ FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
203
+ "It makes no sense to pass in ids and "
204
+ "request them to be shifted");
205
+
206
+ if (successive_ids) {
207
+ FAISS_THROW_IF_NOT_MSG(!xids,
208
+ "It makes no sense to pass in ids and "
209
+ "request them to be shifted");
210
+ FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
211
+ "when adding to IndexShards with sucessive_ids, "
212
+ "only add() in a single pass is supported");
213
+ }
214
+
215
+ idx_t nshard = this->count();
216
+ const idx_t *ids = xids;
217
+
218
+ std::vector<idx_t> aids;
219
+
220
+ if (!ids && !successive_ids) {
221
+ aids.resize(n);
222
+
223
+ for (idx_t i = 0; i < n; i++) {
224
+ aids[i] = this->ntotal + i;
225
+ }
226
+
227
+ ids = aids.data();
228
+ }
229
+
230
+ size_t components_per_vec =
231
+ sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
232
+
233
+ auto fn =
234
+ [n, ids, x, nshard, components_per_vec](int no, IndexT *index) {
235
+ idx_t i0 = (idx_t) no * n / nshard;
236
+ idx_t i1 = ((idx_t) no + 1) * n / nshard;
237
+ auto x0 = x + i0 * components_per_vec;
238
+
239
+ if (index->verbose) {
240
+ printf ("begin add shard %d on %ld points\n", no, n);
241
+ }
242
+
243
+ if (ids) {
244
+ index->add_with_ids (i1 - i0, x0, ids + i0);
245
+ } else {
246
+ index->add (i1 - i0, x0);
247
+ }
248
+
249
+ if (index->verbose) {
250
+ printf ("end add shard %d on %ld points\n", no, i1 - i0);
251
+ }
252
+ };
253
+
254
+ this->runOnIndex(fn);
255
+
256
+ // This is safe to do here because the current thread controls execution in
257
+ // all threads, and nothing else is happening
258
+ this->ntotal += n;
259
+ }
260
+
261
+ template <typename IndexT>
262
+ void
263
+ IndexShardsTemplate<IndexT>::search(idx_t n,
264
+ const component_t *x,
265
+ idx_t k,
266
+ distance_t *distances,
267
+ idx_t *labels) const {
268
+ long nshard = this->count();
269
+
270
+ std::vector<distance_t> all_distances(nshard * k * n);
271
+ std::vector<idx_t> all_labels(nshard * k * n);
272
+
273
+ auto fn =
274
+ [n, k, x, &all_distances, &all_labels](int no, const IndexT *index) {
275
+ if (index->verbose) {
276
+ printf ("begin query shard %d on %ld points\n", no, n);
277
+ }
278
+
279
+ index->search (n, x, k,
280
+ all_distances.data() + no * k * n,
281
+ all_labels.data() + no * k * n);
282
+
283
+ if (index->verbose) {
284
+ printf ("end query shard %d\n", no);
285
+ }
286
+ };
287
+
288
+ this->runOnIndex(fn);
289
+
290
+ std::vector<long> translations(nshard, 0);
291
+
292
+ // Because we just called runOnIndex above, it is safe to access the sub-index
293
+ // ntotal here
294
+ if (successive_ids) {
295
+ translations[0] = 0;
296
+
297
+ for (int s = 0; s + 1 < nshard; s++) {
298
+ translations[s + 1] = translations[s] + this->at(s)->ntotal;
299
+ }
300
+ }
301
+
302
+ if (this->metric_type == METRIC_L2) {
303
+ merge_tables<IndexT, CMin<distance_t, int>>(
304
+ n, k, nshard, distances, labels,
305
+ all_distances, all_labels, translations);
306
+ } else {
307
+ merge_tables<IndexT, CMax<distance_t, int>>(
308
+ n, k, nshard, distances, labels,
309
+ all_distances, all_labels, translations);
310
+ }
311
+ }
312
+
313
+ // explicit instanciations
314
+ template struct IndexShardsTemplate<Index>;
315
+ template struct IndexShardsTemplate<IndexBinary>;
316
+
317
+ } // namespace faiss
@@ -0,0 +1,100 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/IndexBinary.h>
12
+ #include <faiss/impl/ThreadedIndex.h>
13
+
14
+ namespace faiss {
15
+
16
+ /**
17
+ * Index that concatenates the results from several sub-indexes
18
+ */
19
+ template <typename IndexT>
20
+ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
21
+ using idx_t = typename IndexT::idx_t;
22
+ using component_t = typename IndexT::component_t;
23
+ using distance_t = typename IndexT::distance_t;
24
+
25
+ /**
26
+ * The dimension that all sub-indices must share will be the dimension of the
27
+ * first sub-index added
28
+ *
29
+ * @param threaded do we use one thread per sub_index or do
30
+ * queries sequentially?
31
+ * @param successive_ids should we shift the returned ids by
32
+ * the size of each sub-index or return them
33
+ * as they are?
34
+ */
35
+ explicit IndexShardsTemplate(bool threaded = false,
36
+ bool successive_ids = true);
37
+
38
+ /**
39
+ * @param threaded do we use one thread per sub_index or do
40
+ * queries sequentially?
41
+ * @param successive_ids should we shift the returned ids by
42
+ * the size of each sub-index or return them
43
+ * as they are?
44
+ */
45
+ explicit IndexShardsTemplate(idx_t d,
46
+ bool threaded = false,
47
+ bool successive_ids = true);
48
+
49
+ /// int version due to the implicit bool conversion ambiguity of int as
50
+ /// dimension
51
+ explicit IndexShardsTemplate(int d,
52
+ bool threaded = false,
53
+ bool successive_ids = true);
54
+
55
+ /// Alias for addIndex()
56
+ void add_shard(IndexT* index) { this->addIndex(index); }
57
+
58
+ /// Alias for removeIndex()
59
+ void remove_shard(IndexT* index) { this->removeIndex(index); }
60
+
61
+ /// supported only for sub-indices that implement add_with_ids
62
+ void add(idx_t n, const component_t* x) override;
63
+
64
+ /**
65
+ * Cases (successive_ids, xids):
66
+ * - true, non-NULL ERROR: it makes no sense to pass in ids and
67
+ * request them to be shifted
68
+ * - true, NULL OK, but should be called only once (calls add()
69
+ * on sub-indexes).
70
+ * - false, non-NULL OK: will call add_with_ids with passed in xids
71
+ * distributed evenly over shards
72
+ * - false, NULL OK: will call add_with_ids on each sub-index,
73
+ * starting at ntotal
74
+ */
75
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override;
76
+
77
+ void search(idx_t n, const component_t* x, idx_t k,
78
+ distance_t* distances, idx_t* labels) const override;
79
+
80
+ void train(idx_t n, const component_t* x) override;
81
+
82
+ // update metric_type and ntotal. Call if you changes something in
83
+ // the shard indexes.
84
+ void sync_with_shard_indexes();
85
+
86
+ bool successive_ids;
87
+
88
+ protected:
89
+ /// Called just after an index is added
90
+ void onAfterAddIndex(IndexT* index) override;
91
+
92
+ /// Called just after an index is removed
93
+ void onAfterRemoveIndex(IndexT* index) override;
94
+ };
95
+
96
+ using IndexShards = IndexShardsTemplate<Index>;
97
+ using IndexBinaryShards = IndexShardsTemplate<IndexBinary>;
98
+
99
+
100
+ } // namespace faiss