faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,240 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <cstdio>
9
+ #include <cstdlib>
10
+
11
+ #include <memory>
12
+ #include <vector>
13
+
14
+ #include <gtest/gtest.h>
15
+
16
+ #include <faiss/IndexIVF.h>
17
+ #include <faiss/AutoTune.h>
18
+ #include <faiss/index_factory.h>
19
+ #include <faiss/clone_index.h>
20
+ #include <faiss/IVFlib.h>
21
+
22
+ using namespace faiss;
23
+
24
+ typedef Index::idx_t idx_t;
25
+
26
+
27
+ // dimension of the vectors to index
28
+ int d = 32;
29
+
30
+ // nb of training vectors
31
+ size_t nt = 5000;
32
+
33
+ // size of the database points per window step
34
+ size_t nb = 1000;
35
+
36
+ // nb of queries
37
+ size_t nq = 200;
38
+
39
+
40
+ int total_size = 40;
41
+ int window_size = 10;
42
+
43
+
44
+
45
+
46
+
47
+ std::vector<float> make_data(size_t n)
48
+ {
49
+ std::vector <float> database (n * d);
50
+ for (size_t i = 0; i < n * d; i++) {
51
+ database[i] = drand48();
52
+ }
53
+ return database;
54
+ }
55
+
56
+ std::unique_ptr<Index> make_trained_index(const char *index_type)
57
+ {
58
+ auto index = std::unique_ptr<Index>(index_factory(d, index_type));
59
+ auto xt = make_data(nt * d);
60
+ index->train(nt, xt.data());
61
+ ParameterSpace().set_index_parameter (index.get(), "nprobe", 4);
62
+ return index;
63
+ }
64
+
65
+ std::vector<idx_t> search_index(Index *index, const float *xq) {
66
+ int k = 10;
67
+ std::vector<idx_t> I(k * nq);
68
+ std::vector<float> D(k * nq);
69
+ index->search (nq, xq, k, D.data(), I.data());
70
+ return I;
71
+ }
72
+
73
+
74
+
75
+
76
+
77
+ /*************************************************************
78
+ * Test functions for a given index type
79
+ *************************************************************/
80
+
81
+
82
+ // make a few slices of indexes that can be merged
83
+ void make_index_slices (const Index* trained_index,
84
+ std::vector<std::unique_ptr<Index> > & sub_indexes) {
85
+
86
+ for (int i = 0; i < total_size; i++) {
87
+ sub_indexes.emplace_back (clone_index (trained_index));
88
+
89
+ printf ("preparing sub-index # %d\n", i);
90
+
91
+ Index * index = sub_indexes.back().get();
92
+
93
+ auto xb = make_data(nb * d);
94
+ std::vector<faiss::Index::idx_t> ids (nb);
95
+ for (int j = 0; j < nb; j++) {
96
+ ids[j] = lrand48();
97
+ }
98
+ index->add_with_ids (nb, xb.data(), ids.data());
99
+ }
100
+
101
+ }
102
+
103
+ // build merged index explicitly at sliding window position i
104
+ Index *make_merged_index(
105
+ const Index* trained_index,
106
+ const std::vector<std::unique_ptr<Index> > & sub_indexes,
107
+ int i) {
108
+
109
+ Index * merged_index = clone_index (trained_index);
110
+ for (int j = i - window_size + 1; j <= i; j++) {
111
+ if (j < 0 || j >= total_size) continue;
112
+ std::unique_ptr<Index> sub_index (
113
+ clone_index (sub_indexes[j].get()));
114
+ IndexIVF *ivf0 = ivflib::extract_index_ivf (merged_index);
115
+ IndexIVF *ivf1 = ivflib::extract_index_ivf (sub_index.get());
116
+ ivf0->merge_from (*ivf1, 0);
117
+ merged_index->ntotal = ivf0->ntotal;
118
+ }
119
+ return merged_index;
120
+ }
121
+
122
+ int test_sliding_window (const char *index_key) {
123
+
124
+ std::unique_ptr<Index> trained_index = make_trained_index(index_key);
125
+
126
+ // make the index slices
127
+ std::vector<std::unique_ptr<Index> > sub_indexes;
128
+
129
+ make_index_slices (trained_index.get(), sub_indexes);
130
+
131
+ // now slide over the windows
132
+ std::unique_ptr<Index> index (clone_index (trained_index.get()));
133
+ ivflib::SlidingIndexWindow window (index.get());
134
+
135
+ auto xq = make_data (nq * d);
136
+
137
+ for (int i = 0; i < total_size + window_size; i++) {
138
+
139
+ printf ("doing step %d / %d\n", i, total_size + window_size);
140
+
141
+ // update the index
142
+ window.step (i < total_size ? sub_indexes[i].get() : nullptr,
143
+ i >= window_size);
144
+ printf (" current n_slice = %d\n", window.n_slice);
145
+
146
+ auto new_res = search_index (index.get(), xq.data());
147
+
148
+ std::unique_ptr<Index> merged_index (
149
+ make_merged_index (trained_index.get(), sub_indexes, i));
150
+
151
+ auto ref_res = search_index (merged_index.get(), xq.data ());
152
+
153
+ EXPECT_EQ (ref_res.size(), new_res.size());
154
+
155
+ EXPECT_EQ (ref_res, new_res);
156
+ }
157
+ return 0;
158
+ }
159
+
160
+
161
+ int test_sliding_invlists (const char *index_key) {
162
+
163
+ std::unique_ptr<Index> trained_index = make_trained_index(index_key);
164
+
165
+ // make the index slices
166
+ std::vector<std::unique_ptr<Index> > sub_indexes;
167
+
168
+ make_index_slices (trained_index.get(), sub_indexes);
169
+
170
+ // now slide over the windows
171
+ std::unique_ptr<Index> index (clone_index (trained_index.get()));
172
+ IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get());
173
+
174
+ auto xq = make_data (nq * d);
175
+
176
+ for (int i = 0; i < total_size + window_size; i++) {
177
+
178
+ printf ("doing step %d / %d\n", i, total_size + window_size);
179
+
180
+ // update the index
181
+ std::vector<const InvertedLists*> ils;
182
+ for (int j = i - window_size + 1; j <= i; j++) {
183
+ if (j < 0 || j >= total_size) continue;
184
+ ils.push_back (ivflib::extract_index_ivf (
185
+ sub_indexes[j].get())->invlists);
186
+ }
187
+ if (ils.size() == 0) continue;
188
+
189
+ ConcatenatedInvertedLists *ci =
190
+ new ConcatenatedInvertedLists (ils.size(), ils.data());
191
+
192
+ // will be deleted by the index
193
+ index_ivf->replace_invlists (ci, true);
194
+
195
+ printf (" nb invlists = %ld\n", ils.size());
196
+
197
+ auto new_res = search_index (index.get(), xq.data());
198
+
199
+ std::unique_ptr<Index> merged_index (
200
+ make_merged_index (trained_index.get(), sub_indexes, i));
201
+
202
+ auto ref_res = search_index (merged_index.get(), xq.data ());
203
+
204
+ EXPECT_EQ (ref_res.size(), new_res.size());
205
+
206
+ size_t ndiff = 0;
207
+ for (size_t j = 0; j < ref_res.size(); j++) {
208
+ if (ref_res[j] != new_res[j])
209
+ ndiff++;
210
+ }
211
+ printf(" nb differences: %ld / %ld\n",
212
+ ndiff, ref_res.size());
213
+ EXPECT_EQ (ref_res, new_res);
214
+ }
215
+ return 0;
216
+ }
217
+
218
+
219
+
220
+
221
+
222
+ /*************************************************************
223
+ * Test entry points
224
+ *************************************************************/
225
+
226
+ TEST(SlidingWindow, IVFFlat) {
227
+ test_sliding_window ("IVF32,Flat");
228
+ }
229
+
230
+ TEST(SlidingWindow, PCAIVFFlat) {
231
+ test_sliding_window ("PCA24,IVF32,Flat");
232
+ }
233
+
234
+ TEST(SlidingInvlists, IVFFlat) {
235
+ test_sliding_invlists ("IVF32,Flat");
236
+ }
237
+
238
+ TEST(SlidingInvlists, PCAIVFFlat) {
239
+ test_sliding_invlists ("PCA24,IVF32,Flat");
240
+ }
@@ -0,0 +1,253 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/ThreadedIndex.h>
9
+ #include <faiss/IndexReplicas.h>
10
+ #include <faiss/IndexShards.h>
11
+
12
+ #include <chrono>
13
+ #include <gtest/gtest.h>
14
+ #include <memory>
15
+ #include <vector>
16
+ #include <thread>
17
+
18
+ namespace {
19
+
20
+ struct TestException : public std::exception { };
21
+
22
+ struct MockIndex : public faiss::Index {
23
+ explicit MockIndex(idx_t d) :
24
+ faiss::Index(d) {
25
+ resetMock();
26
+ }
27
+
28
+ void resetMock() {
29
+ flag = false;
30
+ nCalled = 0;
31
+ xCalled = nullptr;
32
+ kCalled = 0;
33
+ distancesCalled = nullptr;
34
+ labelsCalled = nullptr;
35
+ }
36
+
37
+ void add(idx_t n, const float* x) override {
38
+ nCalled = n;
39
+ xCalled = x;
40
+ }
41
+
42
+ void search(idx_t n,
43
+ const float* x,
44
+ idx_t k,
45
+ float* distances,
46
+ idx_t* labels) const override {
47
+ nCalled = n;
48
+ xCalled = x;
49
+ kCalled = k;
50
+ distancesCalled = distances;
51
+ labelsCalled = labels;
52
+ }
53
+
54
+ void reset() override { }
55
+
56
+ bool flag;
57
+
58
+ mutable idx_t nCalled;
59
+ mutable const float* xCalled;
60
+ mutable idx_t kCalled;
61
+ mutable float* distancesCalled;
62
+ mutable idx_t* labelsCalled;
63
+ };
64
+
65
+ template <typename IndexT>
66
+ struct MockThreadedIndex : public faiss::ThreadedIndex<IndexT> {
67
+ using idx_t = faiss::Index::idx_t;
68
+
69
+ explicit MockThreadedIndex(bool threaded)
70
+ : faiss::ThreadedIndex<IndexT>(threaded) {
71
+ }
72
+
73
+ void add(idx_t, const float*) override { }
74
+ void search(idx_t, const float*, idx_t, float*, idx_t*) const override {}
75
+ void reset() override {}
76
+ };
77
+
78
+ }
79
+
80
+ TEST(ThreadedIndex, SingleException) {
81
+ std::vector<std::unique_ptr<MockIndex>> idxs;
82
+
83
+ for (int i = 0; i < 3; ++i) {
84
+ idxs.emplace_back(new MockIndex(1));
85
+ }
86
+
87
+ auto fn =
88
+ [](int i, MockIndex* index) {
89
+ if (i == 1) {
90
+ throw TestException();
91
+ } else {
92
+ std::this_thread::sleep_for(std::chrono::milliseconds(i * 250));
93
+
94
+ index->flag = true;
95
+ }
96
+ };
97
+
98
+ // Try with threading and without
99
+ for (bool threaded : {true, false}) {
100
+ // clear flags
101
+ for (auto& idx : idxs) {
102
+ idx->resetMock();
103
+ }
104
+
105
+ MockThreadedIndex<MockIndex> ti(threaded);
106
+ for (auto& idx : idxs) {
107
+ ti.addIndex(idx.get());
108
+ }
109
+
110
+ // The second index should throw
111
+ EXPECT_THROW(ti.runOnIndex(fn), TestException);
112
+
113
+ // Index 0 and 2 should have processed
114
+ EXPECT_TRUE(idxs[0]->flag);
115
+ EXPECT_TRUE(idxs[2]->flag);
116
+ }
117
+ }
118
+
119
+ TEST(ThreadedIndex, MultipleException) {
120
+ std::vector<std::unique_ptr<MockIndex>> idxs;
121
+
122
+ for (int i = 0; i < 3; ++i) {
123
+ idxs.emplace_back(new MockIndex(1));
124
+ }
125
+
126
+ auto fn =
127
+ [](int i, MockIndex* index) {
128
+ if (i < 2) {
129
+ throw TestException();
130
+ } else {
131
+ std::this_thread::sleep_for(std::chrono::milliseconds(i * 250));
132
+
133
+ index->flag = true;
134
+ }
135
+ };
136
+
137
+ // Try with threading and without
138
+ for (bool threaded : {true, false}) {
139
+ // clear flags
140
+ for (auto& idx : idxs) {
141
+ idx->resetMock();
142
+ }
143
+
144
+ MockThreadedIndex<MockIndex> ti(threaded);
145
+ for (auto& idx : idxs) {
146
+ ti.addIndex(idx.get());
147
+ }
148
+
149
+ // Multiple indices threw an exception that was aggregated into a
150
+ // FaissException
151
+ EXPECT_THROW(ti.runOnIndex(fn), faiss::FaissException);
152
+
153
+ // Index 2 should have processed
154
+ EXPECT_TRUE(idxs[2]->flag);
155
+ }
156
+ }
157
+
158
+ TEST(ThreadedIndex, TestReplica) {
159
+ int numReplicas = 5;
160
+ int n = 10 * numReplicas;
161
+ int d = 3;
162
+ int k = 6;
163
+
164
+ // Try with threading and without
165
+ for (bool threaded : {true, false}) {
166
+ std::vector<std::unique_ptr<MockIndex>> idxs;
167
+ faiss::IndexReplicas replica(d);
168
+
169
+ for (int i = 0; i < numReplicas; ++i) {
170
+ idxs.emplace_back(new MockIndex(d));
171
+ replica.addIndex(idxs.back().get());
172
+ }
173
+
174
+ std::vector<float> x(n * d);
175
+ std::vector<float> distances(n * k);
176
+ std::vector<faiss::Index::idx_t> labels(n * k);
177
+
178
+ replica.add(n, x.data());
179
+
180
+ for (int i = 0; i < idxs.size(); ++i) {
181
+ EXPECT_EQ(idxs[i]->nCalled, n);
182
+ EXPECT_EQ(idxs[i]->xCalled, x.data());
183
+ }
184
+
185
+ for (auto& idx : idxs) {
186
+ idx->resetMock();
187
+ }
188
+
189
+ replica.search(n, x.data(), k, distances.data(), labels.data());
190
+
191
+ for (int i = 0; i < idxs.size(); ++i) {
192
+ auto perReplica = n / idxs.size();
193
+
194
+ EXPECT_EQ(idxs[i]->nCalled, perReplica);
195
+ EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perReplica * d);
196
+ EXPECT_EQ(idxs[i]->kCalled, k);
197
+ EXPECT_EQ(idxs[i]->distancesCalled,
198
+ distances.data() + (i * perReplica) * k);
199
+ EXPECT_EQ(idxs[i]->labelsCalled,
200
+ labels.data() + (i * perReplica) * k);
201
+ }
202
+ }
203
+ }
204
+
205
+ TEST(ThreadedIndex, TestShards) {
206
+ int numShards = 7;
207
+ int d = 3;
208
+ int n = 10 * numShards;
209
+ int k = 6;
210
+
211
+ // Try with threading and without
212
+ for (bool threaded : {true, false}) {
213
+ std::vector<std::unique_ptr<MockIndex>> idxs;
214
+ faiss::IndexShards shards(d, threaded);
215
+
216
+ for (int i = 0; i < numShards; ++i) {
217
+ idxs.emplace_back(new MockIndex(d));
218
+ shards.addIndex(idxs.back().get());
219
+ }
220
+
221
+ std::vector<float> x(n * d);
222
+ std::vector<float> distances(n * k);
223
+ std::vector<faiss::Index::idx_t> labels(n * k);
224
+
225
+ shards.add(n, x.data());
226
+
227
+ for (int i = 0; i < idxs.size(); ++i) {
228
+ auto perShard = n / idxs.size();
229
+
230
+ EXPECT_EQ(idxs[i]->nCalled, perShard);
231
+ EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perShard * d);
232
+ }
233
+
234
+ for (auto& idx : idxs) {
235
+ idx->resetMock();
236
+ }
237
+
238
+ shards.search(n, x.data(), k, distances.data(), labels.data());
239
+
240
+ for (int i = 0; i < idxs.size(); ++i) {
241
+ auto perShard = n / idxs.size();
242
+
243
+ EXPECT_EQ(idxs[i]->nCalled, n);
244
+ EXPECT_EQ(idxs[i]->xCalled, x.data());
245
+ EXPECT_EQ(idxs[i]->kCalled, k);
246
+ // There is a temporary buffer used for shards
247
+ EXPECT_EQ(idxs[i]->distancesCalled,
248
+ idxs[0]->distancesCalled + i * k * n);
249
+ EXPECT_EQ(idxs[i]->labelsCalled,
250
+ idxs[0]->labelsCalled + i * k * n);
251
+ }
252
+ }
253
+ }