faiss 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (184) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -1,246 +0,0 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
3
- *
4
- * This source code is licensed under the MIT license found in the
5
- * LICENSE file in the root directory of this source tree.
6
- */
7
-
8
- #include <cstdio>
9
- #include <cstdlib>
10
-
11
- #include <memory>
12
- #include <random>
13
- #include <vector>
14
-
15
- #include <gtest/gtest.h>
16
-
17
- #include <faiss/IndexIVF.h>
18
- #include <faiss/AutoTune.h>
19
- #include <faiss/index_factory.h>
20
- #include <faiss/clone_index.h>
21
- #include <faiss/IVFlib.h>
22
-
23
-
24
- using namespace faiss;
25
-
26
- typedef Index::idx_t idx_t;
27
-
28
-
29
- // dimension of the vectors to index
30
- int d = 32;
31
-
32
- // nb of training vectors
33
- size_t nt = 5000;
34
-
35
- // size of the database points per window step
36
- size_t nb = 1000;
37
-
38
- // nb of queries
39
- size_t nq = 200;
40
-
41
-
42
- int total_size = 40;
43
- int window_size = 10;
44
-
45
-
46
-
47
-
48
-
49
- std::vector<float> make_data(size_t n)
50
- {
51
- std::vector <float> database (n * d);
52
- std::mt19937 rng;
53
- std::uniform_real_distribution<> distrib;
54
- for (size_t i = 0; i < n * d; i++) {
55
- database[i] = distrib(rng);
56
- }
57
- return database;
58
- }
59
-
60
- std::unique_ptr<Index> make_trained_index(const char *index_type)
61
- {
62
- auto index = std::unique_ptr<Index>(index_factory(d, index_type));
63
- auto xt = make_data(nt * d);
64
- index->train(nt, xt.data());
65
- ParameterSpace().set_index_parameter (index.get(), "nprobe", 4);
66
- return index;
67
- }
68
-
69
- std::vector<idx_t> search_index(Index *index, const float *xq) {
70
- int k = 10;
71
- std::vector<idx_t> I(k * nq);
72
- std::vector<float> D(k * nq);
73
- index->search (nq, xq, k, D.data(), I.data());
74
- return I;
75
- }
76
-
77
-
78
-
79
-
80
-
81
- /*************************************************************
82
- * Test functions for a given index type
83
- *************************************************************/
84
-
85
-
86
- // make a few slices of indexes that can be merged
87
- void make_index_slices (const Index* trained_index,
88
- std::vector<std::unique_ptr<Index> > & sub_indexes) {
89
-
90
- for (int i = 0; i < total_size; i++) {
91
- sub_indexes.emplace_back (clone_index (trained_index));
92
-
93
- printf ("preparing sub-index # %d\n", i);
94
-
95
- Index * index = sub_indexes.back().get();
96
-
97
- auto xb = make_data(nb * d);
98
- std::vector<faiss::Index::idx_t> ids (nb);
99
- std::mt19937 rng;
100
- std::uniform_int_distribution<> distrib;
101
- for (int j = 0; j < nb; j++) {
102
- ids[j] = distrib(rng);
103
- }
104
- index->add_with_ids (nb, xb.data(), ids.data());
105
- }
106
-
107
- }
108
-
109
- // build merged index explicitly at sliding window position i
110
- Index *make_merged_index(
111
- const Index* trained_index,
112
- const std::vector<std::unique_ptr<Index> > & sub_indexes,
113
- int i) {
114
-
115
- Index * merged_index = clone_index (trained_index);
116
- for (int j = i - window_size + 1; j <= i; j++) {
117
- if (j < 0 || j >= total_size) continue;
118
- std::unique_ptr<Index> sub_index (
119
- clone_index (sub_indexes[j].get()));
120
- IndexIVF *ivf0 = ivflib::extract_index_ivf (merged_index);
121
- IndexIVF *ivf1 = ivflib::extract_index_ivf (sub_index.get());
122
- ivf0->merge_from (*ivf1, 0);
123
- merged_index->ntotal = ivf0->ntotal;
124
- }
125
- return merged_index;
126
- }
127
-
128
- int test_sliding_window (const char *index_key) {
129
-
130
- std::unique_ptr<Index> trained_index = make_trained_index(index_key);
131
-
132
- // make the index slices
133
- std::vector<std::unique_ptr<Index> > sub_indexes;
134
-
135
- make_index_slices (trained_index.get(), sub_indexes);
136
-
137
- // now slide over the windows
138
- std::unique_ptr<Index> index (clone_index (trained_index.get()));
139
- ivflib::SlidingIndexWindow window (index.get());
140
-
141
- auto xq = make_data (nq * d);
142
-
143
- for (int i = 0; i < total_size + window_size; i++) {
144
-
145
- printf ("doing step %d / %d\n", i, total_size + window_size);
146
-
147
- // update the index
148
- window.step (i < total_size ? sub_indexes[i].get() : nullptr,
149
- i >= window_size);
150
- printf (" current n_slice = %d\n", window.n_slice);
151
-
152
- auto new_res = search_index (index.get(), xq.data());
153
-
154
- std::unique_ptr<Index> merged_index (
155
- make_merged_index (trained_index.get(), sub_indexes, i));
156
-
157
- auto ref_res = search_index (merged_index.get(), xq.data ());
158
-
159
- EXPECT_EQ (ref_res.size(), new_res.size());
160
-
161
- EXPECT_EQ (ref_res, new_res);
162
- }
163
- return 0;
164
- }
165
-
166
-
167
- int test_sliding_invlists (const char *index_key) {
168
-
169
- std::unique_ptr<Index> trained_index = make_trained_index(index_key);
170
-
171
- // make the index slices
172
- std::vector<std::unique_ptr<Index> > sub_indexes;
173
-
174
- make_index_slices (trained_index.get(), sub_indexes);
175
-
176
- // now slide over the windows
177
- std::unique_ptr<Index> index (clone_index (trained_index.get()));
178
- IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get());
179
-
180
- auto xq = make_data (nq * d);
181
-
182
- for (int i = 0; i < total_size + window_size; i++) {
183
-
184
- printf ("doing step %d / %d\n", i, total_size + window_size);
185
-
186
- // update the index
187
- std::vector<const InvertedLists*> ils;
188
- for (int j = i - window_size + 1; j <= i; j++) {
189
- if (j < 0 || j >= total_size) continue;
190
- ils.push_back (ivflib::extract_index_ivf (
191
- sub_indexes[j].get())->invlists);
192
- }
193
- if (ils.size() == 0) continue;
194
-
195
- ConcatenatedInvertedLists *ci =
196
- new ConcatenatedInvertedLists (ils.size(), ils.data());
197
-
198
- // will be deleted by the index
199
- index_ivf->replace_invlists (ci, true);
200
-
201
- printf (" nb invlists = %zd\n", ils.size());
202
-
203
- auto new_res = search_index (index.get(), xq.data());
204
-
205
- std::unique_ptr<Index> merged_index (
206
- make_merged_index (trained_index.get(), sub_indexes, i));
207
-
208
- auto ref_res = search_index (merged_index.get(), xq.data ());
209
-
210
- EXPECT_EQ (ref_res.size(), new_res.size());
211
-
212
- size_t ndiff = 0;
213
- for (size_t j = 0; j < ref_res.size(); j++) {
214
- if (ref_res[j] != new_res[j])
215
- ndiff++;
216
- }
217
- printf(" nb differences: %zd / %zd\n",
218
- ndiff, ref_res.size());
219
- EXPECT_EQ (ref_res, new_res);
220
- }
221
- return 0;
222
- }
223
-
224
-
225
-
226
-
227
-
228
- /*************************************************************
229
- * Test entry points
230
- *************************************************************/
231
-
232
- TEST(SlidingWindow, IVFFlat) {
233
- test_sliding_window ("IVF32,Flat");
234
- }
235
-
236
- TEST(SlidingWindow, PCAIVFFlat) {
237
- test_sliding_window ("PCA24,IVF32,Flat");
238
- }
239
-
240
- TEST(SlidingInvlists, IVFFlat) {
241
- test_sliding_invlists ("IVF32,Flat");
242
- }
243
-
244
- TEST(SlidingInvlists, PCAIVFFlat) {
245
- test_sliding_invlists ("PCA24,IVF32,Flat");
246
- }
@@ -1,253 +0,0 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
3
- *
4
- * This source code is licensed under the MIT license found in the
5
- * LICENSE file in the root directory of this source tree.
6
- */
7
-
8
- #include <faiss/impl/ThreadedIndex.h>
9
- #include <faiss/IndexReplicas.h>
10
- #include <faiss/IndexShards.h>
11
-
12
- #include <chrono>
13
- #include <gtest/gtest.h>
14
- #include <memory>
15
- #include <vector>
16
- #include <thread>
17
-
18
- namespace {
19
-
20
- struct TestException : public std::exception { };
21
-
22
- struct MockIndex : public faiss::Index {
23
- explicit MockIndex(idx_t d) :
24
- faiss::Index(d) {
25
- resetMock();
26
- }
27
-
28
- void resetMock() {
29
- flag = false;
30
- nCalled = 0;
31
- xCalled = nullptr;
32
- kCalled = 0;
33
- distancesCalled = nullptr;
34
- labelsCalled = nullptr;
35
- }
36
-
37
- void add(idx_t n, const float* x) override {
38
- nCalled = n;
39
- xCalled = x;
40
- }
41
-
42
- void search(idx_t n,
43
- const float* x,
44
- idx_t k,
45
- float* distances,
46
- idx_t* labels) const override {
47
- nCalled = n;
48
- xCalled = x;
49
- kCalled = k;
50
- distancesCalled = distances;
51
- labelsCalled = labels;
52
- }
53
-
54
- void reset() override { }
55
-
56
- bool flag;
57
-
58
- mutable idx_t nCalled;
59
- mutable const float* xCalled;
60
- mutable idx_t kCalled;
61
- mutable float* distancesCalled;
62
- mutable idx_t* labelsCalled;
63
- };
64
-
65
- template <typename IndexT>
66
- struct MockThreadedIndex : public faiss::ThreadedIndex<IndexT> {
67
- using idx_t = faiss::Index::idx_t;
68
-
69
- explicit MockThreadedIndex(bool threaded)
70
- : faiss::ThreadedIndex<IndexT>(threaded) {
71
- }
72
-
73
- void add(idx_t, const float*) override { }
74
- void search(idx_t, const float*, idx_t, float*, idx_t*) const override {}
75
- void reset() override {}
76
- };
77
-
78
- }
79
-
80
- TEST(ThreadedIndex, SingleException) {
81
- std::vector<std::unique_ptr<MockIndex>> idxs;
82
-
83
- for (int i = 0; i < 3; ++i) {
84
- idxs.emplace_back(new MockIndex(1));
85
- }
86
-
87
- auto fn =
88
- [](int i, MockIndex* index) {
89
- if (i == 1) {
90
- throw TestException();
91
- } else {
92
- std::this_thread::sleep_for(std::chrono::milliseconds(i * 250));
93
-
94
- index->flag = true;
95
- }
96
- };
97
-
98
- // Try with threading and without
99
- for (bool threaded : {true, false}) {
100
- // clear flags
101
- for (auto& idx : idxs) {
102
- idx->resetMock();
103
- }
104
-
105
- MockThreadedIndex<MockIndex> ti(threaded);
106
- for (auto& idx : idxs) {
107
- ti.addIndex(idx.get());
108
- }
109
-
110
- // The second index should throw
111
- EXPECT_THROW(ti.runOnIndex(fn), TestException);
112
-
113
- // Index 0 and 2 should have processed
114
- EXPECT_TRUE(idxs[0]->flag);
115
- EXPECT_TRUE(idxs[2]->flag);
116
- }
117
- }
118
-
119
- TEST(ThreadedIndex, MultipleException) {
120
- std::vector<std::unique_ptr<MockIndex>> idxs;
121
-
122
- for (int i = 0; i < 3; ++i) {
123
- idxs.emplace_back(new MockIndex(1));
124
- }
125
-
126
- auto fn =
127
- [](int i, MockIndex* index) {
128
- if (i < 2) {
129
- throw TestException();
130
- } else {
131
- std::this_thread::sleep_for(std::chrono::milliseconds(i * 250));
132
-
133
- index->flag = true;
134
- }
135
- };
136
-
137
- // Try with threading and without
138
- for (bool threaded : {true, false}) {
139
- // clear flags
140
- for (auto& idx : idxs) {
141
- idx->resetMock();
142
- }
143
-
144
- MockThreadedIndex<MockIndex> ti(threaded);
145
- for (auto& idx : idxs) {
146
- ti.addIndex(idx.get());
147
- }
148
-
149
- // Multiple indices threw an exception that was aggregated into a
150
- // FaissException
151
- EXPECT_THROW(ti.runOnIndex(fn), faiss::FaissException);
152
-
153
- // Index 2 should have processed
154
- EXPECT_TRUE(idxs[2]->flag);
155
- }
156
- }
157
-
158
- TEST(ThreadedIndex, TestReplica) {
159
- int numReplicas = 5;
160
- int n = 10 * numReplicas;
161
- int d = 3;
162
- int k = 6;
163
-
164
- // Try with threading and without
165
- for (bool threaded : {true, false}) {
166
- std::vector<std::unique_ptr<MockIndex>> idxs;
167
- faiss::IndexReplicas replica(d);
168
-
169
- for (int i = 0; i < numReplicas; ++i) {
170
- idxs.emplace_back(new MockIndex(d));
171
- replica.addIndex(idxs.back().get());
172
- }
173
-
174
- std::vector<float> x(n * d);
175
- std::vector<float> distances(n * k);
176
- std::vector<faiss::Index::idx_t> labels(n * k);
177
-
178
- replica.add(n, x.data());
179
-
180
- for (int i = 0; i < idxs.size(); ++i) {
181
- EXPECT_EQ(idxs[i]->nCalled, n);
182
- EXPECT_EQ(idxs[i]->xCalled, x.data());
183
- }
184
-
185
- for (auto& idx : idxs) {
186
- idx->resetMock();
187
- }
188
-
189
- replica.search(n, x.data(), k, distances.data(), labels.data());
190
-
191
- for (int i = 0; i < idxs.size(); ++i) {
192
- auto perReplica = n / idxs.size();
193
-
194
- EXPECT_EQ(idxs[i]->nCalled, perReplica);
195
- EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perReplica * d);
196
- EXPECT_EQ(idxs[i]->kCalled, k);
197
- EXPECT_EQ(idxs[i]->distancesCalled,
198
- distances.data() + (i * perReplica) * k);
199
- EXPECT_EQ(idxs[i]->labelsCalled,
200
- labels.data() + (i * perReplica) * k);
201
- }
202
- }
203
- }
204
-
205
- TEST(ThreadedIndex, TestShards) {
206
- int numShards = 7;
207
- int d = 3;
208
- int n = 10 * numShards;
209
- int k = 6;
210
-
211
- // Try with threading and without
212
- for (bool threaded : {true, false}) {
213
- std::vector<std::unique_ptr<MockIndex>> idxs;
214
- faiss::IndexShards shards(d, threaded);
215
-
216
- for (int i = 0; i < numShards; ++i) {
217
- idxs.emplace_back(new MockIndex(d));
218
- shards.addIndex(idxs.back().get());
219
- }
220
-
221
- std::vector<float> x(n * d);
222
- std::vector<float> distances(n * k);
223
- std::vector<faiss::Index::idx_t> labels(n * k);
224
-
225
- shards.add(n, x.data());
226
-
227
- for (int i = 0; i < idxs.size(); ++i) {
228
- auto perShard = n / idxs.size();
229
-
230
- EXPECT_EQ(idxs[i]->nCalled, perShard);
231
- EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perShard * d);
232
- }
233
-
234
- for (auto& idx : idxs) {
235
- idx->resetMock();
236
- }
237
-
238
- shards.search(n, x.data(), k, distances.data(), labels.data());
239
-
240
- for (int i = 0; i < idxs.size(); ++i) {
241
- auto perShard = n / idxs.size();
242
-
243
- EXPECT_EQ(idxs[i]->nCalled, n);
244
- EXPECT_EQ(idxs[i]->xCalled, x.data());
245
- EXPECT_EQ(idxs[i]->kCalled, k);
246
- // There is a temporary buffer used for shards
247
- EXPECT_EQ(idxs[i]->distancesCalled,
248
- idxs[0]->distancesCalled + i * k * n);
249
- EXPECT_EQ(idxs[i]->labelsCalled,
250
- idxs[0]->labelsCalled + i * k * n);
251
- }
252
- }
253
- }