faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -1,193 +0,0 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
3
- *
4
- * This source code is licensed under the MIT license found in the
5
- * LICENSE file in the root directory of this source tree.
6
- */
7
-
8
- #include <cstdio>
9
- #include <cstdlib>
10
-
11
- #include <memory>
12
- #include <vector>
13
- #include <random>
14
-
15
- #include <gtest/gtest.h>
16
-
17
- #include <faiss/IndexIVF.h>
18
- #include <faiss/index_factory.h>
19
- #include <faiss/VectorTransform.h>
20
- #include <faiss/IVFlib.h>
21
-
22
-
23
- namespace {
24
-
25
- typedef faiss::Index::idx_t idx_t;
26
-
27
- /*************************************************************
28
- * Test utils
29
- *************************************************************/
30
-
31
-
32
- // dimension of the vectors to index
33
- int d = 64;
34
-
35
- // size of the database we plan to index
36
- size_t nb = 8000;
37
-
38
- // nb of queries
39
- size_t nq = 200;
40
-
41
- std::mt19937 rng;
42
-
43
- std::vector<float> make_data(size_t n)
44
- {
45
- std::vector <float> database (n * d);
46
- std::uniform_real_distribution<> distrib;
47
- for (size_t i = 0; i < n * d; i++) {
48
- database[i] = distrib(rng);
49
- }
50
- return database;
51
- }
52
-
53
- std::unique_ptr<faiss::Index> make_index(const char *index_type,
54
- const std::vector<float> & x) {
55
-
56
- auto index = std::unique_ptr<faiss::Index> (
57
- faiss::index_factory(d, index_type));
58
- index->train(nb, x.data());
59
- index->add(nb, x.data());
60
- return index;
61
- }
62
-
63
- /*************************************************************
64
- * Test functions for a given index type
65
- *************************************************************/
66
-
67
- bool test_search_centroid(const char *index_key) {
68
- std::vector<float> xb = make_data(nb); // database vectors
69
- auto index = make_index(index_key, xb);
70
-
71
- /* First test: find the centroids associated to the database
72
- vectors and make sure that each vector does indeed appear in
73
- the inverted list corresponding to its centroid */
74
-
75
- std::vector<idx_t> centroid_ids (nb);
76
- faiss::ivflib::search_centroid(
77
- index.get(), xb.data(), nb, centroid_ids.data());
78
-
79
- const faiss::IndexIVF * ivf = faiss::ivflib::extract_index_ivf
80
- (index.get());
81
-
82
- for(int i = 0; i < nb; i++) {
83
- bool found = false;
84
- int list_no = centroid_ids[i];
85
- int list_size = ivf->invlists->list_size (list_no);
86
- auto * list = ivf->invlists->get_ids (list_no);
87
-
88
- for(int j = 0; j < list_size; j++) {
89
- if (list[j] == i) {
90
- found = true;
91
- break;
92
- }
93
- }
94
- if(!found) return false;
95
- }
96
- return true;
97
- }
98
-
99
- int test_search_and_return_centroids(const char *index_key) {
100
- std::vector<float> xb = make_data(nb); // database vectors
101
- auto index = make_index(index_key, xb);
102
-
103
- std::vector<idx_t> centroid_ids (nb);
104
- faiss::ivflib::search_centroid(index.get(), xb.data(),
105
- nb, centroid_ids.data());
106
-
107
- faiss::IndexIVF * ivf =
108
- faiss::ivflib::extract_index_ivf (index.get());
109
- ivf->nprobe = 4;
110
-
111
- std::vector<float> xq = make_data(nq); // database vectors
112
-
113
- int k = 5;
114
-
115
- // compute a reference search result
116
-
117
- std::vector<idx_t> refI (nq * k);
118
- std::vector<float> refD (nq * k);
119
- index->search (nq, xq.data(), k, refD.data(), refI.data());
120
-
121
- // compute search result
122
-
123
- std::vector<idx_t> newI (nq * k);
124
- std::vector<float> newD (nq * k);
125
-
126
- std::vector<idx_t> query_centroid_ids (nq);
127
- std::vector<idx_t> result_centroid_ids (nq * k);
128
-
129
- faiss::ivflib::search_and_return_centroids(index.get(),
130
- nq, xq.data(), k,
131
- newD.data(), newI.data(),
132
- query_centroid_ids.data(),
133
- result_centroid_ids.data());
134
-
135
- // first verify that we have the same result as the standard search
136
-
137
- if (newI != refI) {
138
- return 1;
139
- }
140
-
141
- // then check if the result ids are indeed in the inverted list
142
- // they are supposed to be in
143
-
144
- for(int i = 0; i < nq * k; i++) {
145
- int list_no = result_centroid_ids[i];
146
- int result_no = newI[i];
147
-
148
- if (result_no < 0) continue;
149
-
150
- bool found = false;
151
-
152
- int list_size = ivf->invlists->list_size (list_no);
153
- auto * list = ivf->invlists->get_ids (list_no);
154
-
155
- for(int j = 0; j < list_size; j++) {
156
- if (list[j] == result_no) {
157
- found = true;
158
- break;
159
- }
160
- }
161
- if(!found) return 2;
162
- }
163
- return 0;
164
- }
165
-
166
- } // namespace
167
-
168
-
169
- /*************************************************************
170
- * Test entry points
171
- *************************************************************/
172
-
173
- TEST(test_search_centroid, IVFFlat) {
174
- bool ok = test_search_centroid("IVF32,Flat");
175
- EXPECT_TRUE(ok);
176
- }
177
-
178
- TEST(test_search_centroid, PCAIVFFlat) {
179
- bool ok = test_search_centroid("PCA16,IVF32,Flat");
180
- EXPECT_TRUE(ok);
181
- }
182
-
183
- TEST(test_search_and_return_centroids, IVFFlat) {
184
- int err = test_search_and_return_centroids("IVF32,Flat");
185
- EXPECT_NE(err, 1);
186
- EXPECT_NE(err, 2);
187
- }
188
-
189
- TEST(test_search_and_return_centroids, PCAIVFFlat) {
190
- int err = test_search_and_return_centroids("PCA16,IVF32,Flat");
191
- EXPECT_NE(err, 1);
192
- EXPECT_NE(err, 2);
193
- }
@@ -1,236 +0,0 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
3
- *
4
- * This source code is licensed under the MIT license found in the
5
- * LICENSE file in the root directory of this source tree.
6
- */
7
-
8
- #include <cstdio>
9
- #include <cstdlib>
10
-
11
- #include <memory>
12
- #include <vector>
13
- #include <random>
14
-
15
- #include <gtest/gtest.h>
16
-
17
- #include <faiss/IndexIVF.h>
18
- #include <faiss/IndexBinaryIVF.h>
19
- #include <faiss/index_factory.h>
20
- #include <faiss/AutoTune.h>
21
- #include <faiss/IVFlib.h>
22
-
23
-
24
- using namespace faiss;
25
-
26
- namespace {
27
-
28
- typedef Index::idx_t idx_t;
29
-
30
-
31
- // dimension of the vectors to index
32
- int d = 32;
33
-
34
- // size of the database we plan to index
35
- size_t nb = 1000;
36
-
37
- // nb of queries
38
- size_t nq = 200;
39
-
40
- std::mt19937 rng;
41
-
42
-
43
- std::vector<float> make_data(size_t n)
44
- {
45
- std::vector <float> database (n * d);
46
- std::uniform_real_distribution<> distrib;
47
- for (size_t i = 0; i < n * d; i++) {
48
- database[i] = distrib(rng);
49
- }
50
- return database;
51
- }
52
-
53
- std::unique_ptr<Index> make_index(const char *index_type,
54
- MetricType metric,
55
- const std::vector<float> & x)
56
- {
57
- std::unique_ptr<Index> index(index_factory(d, index_type, metric));
58
- index->train(nb, x.data());
59
- index->add(nb, x.data());
60
- return index;
61
- }
62
-
63
- std::vector<idx_t> search_index(Index *index, const float *xq) {
64
- int k = 10;
65
- std::vector<idx_t> I(k * nq);
66
- std::vector<float> D(k * nq);
67
- index->search (nq, xq, k, D.data(), I.data());
68
- return I;
69
- }
70
-
71
- std::vector<idx_t> search_index_with_params(
72
- Index *index, const float *xq, IVFSearchParameters *params) {
73
- int k = 10;
74
- std::vector<idx_t> I(k * nq);
75
- std::vector<float> D(k * nq);
76
- ivflib::search_with_parameters (index, nq, xq, k,
77
- D.data(), I.data(), params);
78
- return I;
79
- }
80
-
81
-
82
-
83
-
84
- /*************************************************************
85
- * Test functions for a given index type
86
- *************************************************************/
87
-
88
- int test_params_override (const char *index_key, MetricType metric) {
89
- std::vector<float> xb = make_data(nb); // database vectors
90
- auto index = make_index(index_key, metric, xb);
91
- //index->train(nb, xb.data());
92
- // index->add(nb, xb.data());
93
- std::vector<float> xq = make_data(nq);
94
- ParameterSpace ps;
95
- ps.set_index_parameter(index.get(), "nprobe", 2);
96
- auto res2ref = search_index(index.get(), xq.data());
97
- ps.set_index_parameter(index.get(), "nprobe", 9);
98
- auto res9ref = search_index(index.get(), xq.data());
99
- ps.set_index_parameter(index.get(), "nprobe", 1);
100
-
101
- IVFSearchParameters params;
102
- params.max_codes = 0;
103
- params.nprobe = 2;
104
- auto res2new = search_index_with_params(index.get(), xq.data(), &params);
105
- params.nprobe = 9;
106
- auto res9new = search_index_with_params(index.get(), xq.data(), &params);
107
-
108
- if (res2ref != res2new)
109
- return 2;
110
-
111
- if (res9ref != res9new)
112
- return 9;
113
-
114
- return 0;
115
- }
116
-
117
-
118
- } // namespace
119
-
120
-
121
- /*************************************************************
122
- * Test entry points
123
- *************************************************************/
124
-
125
- TEST(TPO, IVFFlat) {
126
- int err1 = test_params_override ("IVF32,Flat", METRIC_L2);
127
- EXPECT_EQ(err1, 0);
128
- int err2 = test_params_override ("IVF32,Flat", METRIC_INNER_PRODUCT);
129
- EXPECT_EQ(err2, 0);
130
- }
131
-
132
- TEST(TPO, IVFPQ) {
133
- int err1 = test_params_override ("IVF32,PQ8np", METRIC_L2);
134
- EXPECT_EQ(err1, 0);
135
- int err2 = test_params_override ("IVF32,PQ8np", METRIC_INNER_PRODUCT);
136
- EXPECT_EQ(err2, 0);
137
- }
138
-
139
- TEST(TPO, IVFSQ) {
140
- int err1 = test_params_override ("IVF32,SQ8", METRIC_L2);
141
- EXPECT_EQ(err1, 0);
142
- int err2 = test_params_override ("IVF32,SQ8", METRIC_INNER_PRODUCT);
143
- EXPECT_EQ(err2, 0);
144
- }
145
-
146
- TEST(TPO, IVFFlatPP) {
147
- int err1 = test_params_override ("PCA16,IVF32,SQ8", METRIC_L2);
148
- EXPECT_EQ(err1, 0);
149
- int err2 = test_params_override ("PCA16,IVF32,SQ8", METRIC_INNER_PRODUCT);
150
- EXPECT_EQ(err2, 0);
151
- }
152
-
153
-
154
-
155
- /*************************************************************
156
- * Same for binary indexes
157
- *************************************************************/
158
-
159
-
160
- std::vector<uint8_t> make_data_binary(size_t n) {
161
- std::vector <uint8_t> database (n * d / 8);
162
- std::uniform_int_distribution<> distrib;
163
- for (size_t i = 0; i < n * d / 8; i++) {
164
- database[i] = distrib(rng);
165
- }
166
- return database;
167
- }
168
-
169
- std::unique_ptr<IndexBinaryIVF> make_index(const char *index_type,
170
- const std::vector<uint8_t> & x)
171
- {
172
-
173
- auto index = std::unique_ptr<IndexBinaryIVF>
174
- (dynamic_cast<IndexBinaryIVF*>(index_binary_factory (d, index_type)));
175
- index->train(nb, x.data());
176
- index->add(nb, x.data());
177
- return index;
178
- }
179
-
180
- std::vector<idx_t> search_index(IndexBinaryIVF *index, const uint8_t *xq) {
181
- int k = 10;
182
- std::vector<idx_t> I(k * nq);
183
- std::vector<int32_t> D(k * nq);
184
- index->search (nq, xq, k, D.data(), I.data());
185
- return I;
186
- }
187
-
188
- std::vector<idx_t> search_index_with_params(
189
- IndexBinaryIVF *index, const uint8_t *xq, IVFSearchParameters *params) {
190
- int k = 10;
191
- std::vector<idx_t> I(k * nq);
192
- std::vector<int32_t> D(k * nq);
193
-
194
- std::vector<idx_t> Iq(params->nprobe * nq);
195
- std::vector<int32_t> Dq(params->nprobe * nq);
196
-
197
- index->quantizer->search(nq, xq, params->nprobe,
198
- Dq.data(), Iq.data());
199
- index->search_preassigned(nq, xq, k, Iq.data(), Dq.data(),
200
- D.data(), I.data(),
201
- false, params);
202
- return I;
203
- }
204
-
205
- int test_params_override_binary (const char *index_key) {
206
- std::vector<uint8_t> xb = make_data_binary(nb); // database vectors
207
- auto index = make_index (index_key, xb);
208
- index->train(nb, xb.data());
209
- index->add(nb, xb.data());
210
- std::vector<uint8_t> xq = make_data_binary(nq);
211
- index->nprobe = 2;
212
- auto res2ref = search_index(index.get(), xq.data());
213
- index->nprobe = 9;
214
- auto res9ref = search_index(index.get(), xq.data());
215
- index->nprobe = 1;
216
-
217
- IVFSearchParameters params;
218
- params.max_codes = 0;
219
- params.nprobe = 2;
220
- auto res2new = search_index_with_params(index.get(), xq.data(), &params);
221
- params.nprobe = 9;
222
- auto res9new = search_index_with_params(index.get(), xq.data(), &params);
223
-
224
- if (res2ref != res2new)
225
- return 2;
226
-
227
- if (res9ref != res9new)
228
- return 9;
229
-
230
- return 0;
231
- }
232
-
233
- TEST(TPOB, IVF) {
234
- int err1 = test_params_override_binary ("BIVF32");
235
- EXPECT_EQ(err1, 0);
236
- }
@@ -1,98 +0,0 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
3
- *
4
- * This source code is licensed under the MIT license found in the
5
- * LICENSE file in the root directory of this source tree.
6
- */
7
-
8
-
9
- #include <iostream>
10
- #include <vector>
11
- #include <memory>
12
-
13
- #include <gtest/gtest.h>
14
-
15
- #include <faiss/impl/ProductQuantizer.h>
16
-
17
-
18
- namespace {
19
-
20
- const std::vector<uint64_t> random_vector(size_t s) {
21
- std::vector<uint64_t> v(s, 0);
22
- for (size_t i = 0; i < s; ++i) {
23
- v[i] = rand();
24
- }
25
-
26
- return v;
27
- }
28
-
29
- } // namespace
30
-
31
-
32
- TEST(PQEncoderGeneric, encode) {
33
- const int nsubcodes = 97;
34
- const int minbits = 1;
35
- const int maxbits = 24;
36
- const std::vector<uint64_t> values = random_vector(nsubcodes);
37
-
38
- for(int nbits = minbits; nbits <= maxbits; ++nbits) {
39
- std::cerr << "nbits = " << nbits << std::endl;
40
-
41
- const uint64_t mask = (1ull << nbits) - 1;
42
- std::unique_ptr<uint8_t[]> codes(
43
- new uint8_t[(nsubcodes * maxbits + 7) / 8]
44
- );
45
-
46
- // NOTE(hoss): Necessary scope to ensure trailing bits are flushed to mem.
47
- {
48
- faiss::PQEncoderGeneric encoder(codes.get(), nbits);
49
- for (const auto& v : values) {
50
- encoder.encode(v & mask);
51
- }
52
- }
53
-
54
- faiss::PQDecoderGeneric decoder(codes.get(), nbits);
55
- for (int i = 0; i < nsubcodes; ++i) {
56
- uint64_t v = decoder.decode();
57
- EXPECT_EQ(values[i] & mask, v);
58
- }
59
- }
60
- }
61
-
62
-
63
- TEST(PQEncoder8, encode) {
64
- const int nsubcodes = 100;
65
- const std::vector<uint64_t> values = random_vector(nsubcodes);
66
- const uint64_t mask = 0xFF;
67
- std::unique_ptr<uint8_t[]> codes(new uint8_t[nsubcodes]);
68
-
69
- faiss::PQEncoder8 encoder(codes.get(), 8);
70
- for (const auto& v : values) {
71
- encoder.encode(v & mask);
72
- }
73
-
74
- faiss::PQDecoder8 decoder(codes.get(), 8);
75
- for (int i = 0; i < nsubcodes; ++i) {
76
- uint64_t v = decoder.decode();
77
- EXPECT_EQ(values[i] & mask, v);
78
- }
79
- }
80
-
81
-
82
- TEST(PQEncoder16, encode) {
83
- const int nsubcodes = 100;
84
- const std::vector<uint64_t> values = random_vector(nsubcodes);
85
- const uint64_t mask = 0xFFFF;
86
- std::unique_ptr<uint8_t[]> codes(new uint8_t[2 * nsubcodes]);
87
-
88
- faiss::PQEncoder16 encoder(codes.get(), 16);
89
- for (const auto& v : values) {
90
- encoder.encode(v & mask);
91
- }
92
-
93
- faiss::PQDecoder16 decoder(codes.get(), 16);
94
- for (int i = 0; i < nsubcodes; ++i) {
95
- uint64_t v = decoder.decode();
96
- EXPECT_EQ(values[i] & mask, v);
97
- }
98
- }