faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -11,6 +11,7 @@
11
11
  #include <faiss/gpu/impl/IndexUtils.h>
12
12
  #include <faiss/gpu/test/TestUtils.h>
13
13
  #include <faiss/gpu/utils/DeviceUtils.h>
14
+ #include <faiss/utils/random.h>
14
15
  #include <faiss/utils/utils.h>
15
16
  #include <gtest/gtest.h>
16
17
  #include <sstream>
@@ -164,6 +165,30 @@ TEST(TestGpuIndexBinaryFlat, LargeIndex) {
164
165
  compareBinaryDist(cpuDist, cpuLabels, gpuDist, gpuLabels, nq, k);
165
166
  }
166
167
 
168
+ TEST(TestGpuIndexBinaryFlat, Reconstruct) {
169
+ int n = 1000;
170
+ std::vector<uint8_t> xb(8 * n);
171
+ faiss::byte_rand(xb.data(), xb.size(), 123);
172
+ std::unique_ptr<faiss::IndexBinaryFlat> index(
173
+ new faiss::IndexBinaryFlat(64));
174
+ index->add(n, xb.data());
175
+
176
+ std::vector<uint8_t> xb3(8 * n);
177
+ index->reconstruct_n(0, index->ntotal, xb3.data());
178
+ EXPECT_EQ(xb, xb3);
179
+
180
+ faiss::gpu::StandardGpuResources res;
181
+ res.noTempMemory();
182
+
183
+ std::unique_ptr<faiss::gpu::GpuIndexBinaryFlat> index2(
184
+ new faiss::gpu::GpuIndexBinaryFlat(&res, index.get()));
185
+
186
+ std::vector<uint8_t> xb2(8 * n);
187
+
188
+ index2->reconstruct_n(0, index->ntotal, xb2.data());
189
+ EXPECT_EQ(xb2, xb3);
190
+ }
191
+
167
192
  int main(int argc, char** argv) {
168
193
  testing::InitGoogleTest(&argc, argv);
169
194
 
@@ -28,7 +28,8 @@ struct TestFlatOptions {
28
28
  numVecsOverride(-1),
29
29
  numQueriesOverride(-1),
30
30
  kOverride(-1),
31
- dimOverride(-1) {}
31
+ dimOverride(-1),
32
+ use_raft(false) {}
32
33
 
33
34
  faiss::MetricType metric;
34
35
  float metricArg;
@@ -38,6 +39,7 @@ struct TestFlatOptions {
38
39
  int numQueriesOverride;
39
40
  int kOverride;
40
41
  int dimOverride;
42
+ bool use_raft;
41
43
  };
42
44
 
43
45
  void testFlat(const TestFlatOptions& opt) {
@@ -73,6 +75,7 @@ void testFlat(const TestFlatOptions& opt) {
73
75
  faiss::gpu::GpuIndexFlatConfig config;
74
76
  config.device = device;
75
77
  config.useFloat16 = opt.useFloat16;
78
+ config.use_raft = opt.use_raft;
76
79
 
77
80
  faiss::gpu::GpuIndexFlat gpuIndex(&res, dim, opt.metric, config);
78
81
  gpuIndex.metric_arg = opt.metricArg;
@@ -110,6 +113,11 @@ TEST(TestGpuIndexFlat, IP_Float32) {
110
113
  opt.useFloat16 = false;
111
114
 
112
115
  testFlat(opt);
116
+
117
+ #if defined USE_NVIDIA_RAFT
118
+ opt.use_raft = true;
119
+ testFlat(opt);
120
+ #endif
113
121
  }
114
122
  }
115
123
 
@@ -119,6 +127,11 @@ TEST(TestGpuIndexFlat, L1_Float32) {
119
127
  opt.useFloat16 = false;
120
128
 
121
129
  testFlat(opt);
130
+
131
+ #if defined USE_NVIDIA_RAFT
132
+ opt.use_raft = true;
133
+ testFlat(opt);
134
+ #endif
122
135
  }
123
136
 
124
137
  TEST(TestGpuIndexFlat, Lp_Float32) {
@@ -128,6 +141,10 @@ TEST(TestGpuIndexFlat, Lp_Float32) {
128
141
  opt.useFloat16 = false;
129
142
 
130
143
  testFlat(opt);
144
+ #if defined USE_NVIDIA_RAFT
145
+ opt.use_raft = true;
146
+ testFlat(opt);
147
+ #endif
131
148
  }
132
149
 
133
150
  TEST(TestGpuIndexFlat, L2_Float32) {
@@ -138,6 +155,10 @@ TEST(TestGpuIndexFlat, L2_Float32) {
138
155
  opt.useFloat16 = false;
139
156
 
140
157
  testFlat(opt);
158
+ #if defined USE_NVIDIA_RAFT
159
+ opt.use_raft = true;
160
+ testFlat(opt);
161
+ #endif
141
162
  }
142
163
  }
143
164
 
@@ -152,6 +173,10 @@ TEST(TestGpuIndexFlat, L2_k_2048) {
152
173
  opt.numVecsOverride = 10000;
153
174
 
154
175
  testFlat(opt);
176
+ #if defined USE_NVIDIA_RAFT
177
+ opt.use_raft = true;
178
+ testFlat(opt);
179
+ #endif
155
180
  }
156
181
  }
157
182
 
@@ -164,6 +189,10 @@ TEST(TestGpuIndexFlat, L2_Float32_K1) {
164
189
  opt.kOverride = 1;
165
190
 
166
191
  testFlat(opt);
192
+ #if defined USE_NVIDIA_RAFT
193
+ opt.use_raft = true;
194
+ testFlat(opt);
195
+ #endif
167
196
  }
168
197
  }
169
198
 
@@ -174,6 +203,10 @@ TEST(TestGpuIndexFlat, IP_Float16) {
174
203
  opt.useFloat16 = true;
175
204
 
176
205
  testFlat(opt);
206
+ #if defined USE_NVIDIA_RAFT
207
+ opt.use_raft = true;
208
+ testFlat(opt);
209
+ #endif
177
210
  }
178
211
  }
179
212
 
@@ -184,6 +217,10 @@ TEST(TestGpuIndexFlat, L2_Float16) {
184
217
  opt.useFloat16 = true;
185
218
 
186
219
  testFlat(opt);
220
+ #if defined USE_NVIDIA_RAFT
221
+ opt.use_raft = true;
222
+ testFlat(opt);
223
+ #endif
187
224
  }
188
225
  }
189
226
 
@@ -196,6 +233,10 @@ TEST(TestGpuIndexFlat, L2_Float16_K1) {
196
233
  opt.kOverride = 1;
197
234
 
198
235
  testFlat(opt);
236
+ #if defined USE_NVIDIA_RAFT
237
+ opt.use_raft = true;
238
+ testFlat(opt);
239
+ #endif
199
240
  }
200
241
  }
201
242
 
@@ -213,6 +254,10 @@ TEST(TestGpuIndexFlat, L2_Tiling) {
213
254
  opt.kOverride = 64;
214
255
 
215
256
  testFlat(opt);
257
+ #if defined USE_NVIDIA_RAFT
258
+ opt.use_raft = true;
259
+ testFlat(opt);
260
+ #endif
216
261
  }
217
262
  }
218
263
 
@@ -223,7 +268,7 @@ TEST(TestGpuIndexFlat, QueryEmpty) {
223
268
  faiss::gpu::GpuIndexFlatConfig config;
224
269
  config.device = 0;
225
270
  config.useFloat16 = false;
226
-
271
+ config.use_raft = false;
227
272
  int dim = 128;
228
273
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config);
229
274
 
@@ -247,7 +292,7 @@ TEST(TestGpuIndexFlat, QueryEmpty) {
247
292
  }
248
293
  }
249
294
 
250
- TEST(TestGpuIndexFlat, CopyFrom) {
295
+ void testCopyFrom(bool use_raft) {
251
296
  int numVecs = faiss::gpu::randVal(100, 200);
252
297
  int dim = faiss::gpu::randVal(1, 1000);
253
298
 
@@ -265,6 +310,7 @@ TEST(TestGpuIndexFlat, CopyFrom) {
265
310
  faiss::gpu::GpuIndexFlatConfig config;
266
311
  config.device = device;
267
312
  config.useFloat16 = useFloat16;
313
+ config.use_raft = use_raft;
268
314
 
269
315
  // Fill with garbage values
270
316
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, 2000, config);
@@ -293,7 +339,17 @@ TEST(TestGpuIndexFlat, CopyFrom) {
293
339
  }
294
340
  }
295
341
 
296
- TEST(TestGpuIndexFlat, CopyTo) {
342
+ TEST(TestGpuIndexFlat, CopyFrom) {
343
+ testCopyFrom(false);
344
+ }
345
+
346
+ #if defined USE_NVIDIA_RAFT
347
+ TEST(TestRaftGpuIndexFlat, CopyFrom) {
348
+ testCopyFrom(true);
349
+ }
350
+ #endif
351
+
352
+ void testCopyTo(bool use_raft) {
297
353
  faiss::gpu::StandardGpuResources res;
298
354
  res.noTempMemory();
299
355
 
@@ -307,6 +363,7 @@ TEST(TestGpuIndexFlat, CopyTo) {
307
363
  faiss::gpu::GpuIndexFlatConfig config;
308
364
  config.device = device;
309
365
  config.useFloat16 = useFloat16;
366
+ config.use_raft = use_raft;
310
367
 
311
368
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config);
312
369
  gpuIndex.add(numVecs, vecs.data());
@@ -333,7 +390,17 @@ TEST(TestGpuIndexFlat, CopyTo) {
333
390
  }
334
391
  }
335
392
 
336
- TEST(TestGpuIndexFlat, UnifiedMemory) {
393
+ TEST(TestGpuIndexFlat, CopyTo) {
394
+ testCopyTo(false);
395
+ }
396
+
397
+ #if defined USE_NVIDIA_RAFT
398
+ TEST(TestRaftGpuIndexFlat, CopyTo) {
399
+ testCopyTo(true);
400
+ }
401
+ #endif
402
+
403
+ void testUnifiedMemory(bool use_raft) {
337
404
  // Construct on a random device to test multi-device, if we have
338
405
  // multiple devices
339
406
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -359,6 +426,7 @@ TEST(TestGpuIndexFlat, UnifiedMemory) {
359
426
  faiss::gpu::GpuIndexFlatConfig config;
360
427
  config.device = device;
361
428
  config.memorySpace = faiss::gpu::MemorySpace::Unified;
429
+ config.use_raft = use_raft;
362
430
 
363
431
  faiss::gpu::GpuIndexFlatL2 gpuIndexL2(&res, dim, config);
364
432
 
@@ -380,7 +448,17 @@ TEST(TestGpuIndexFlat, UnifiedMemory) {
380
448
  0.015f);
381
449
  }
382
450
 
383
- TEST(TestGpuIndexFlat, LargeIndex) {
451
+ TEST(TestGpuIndexFlat, UnifiedMemory) {
452
+ testUnifiedMemory(false);
453
+ }
454
+
455
+ #if defined USE_NVIDIA_RAFT
456
+ TEST(TestRaftGpuIndexFlat, UnifiedMemory) {
457
+ testUnifiedMemory(true);
458
+ }
459
+ #endif
460
+
461
+ void testLargeIndex(bool use_raft) {
384
462
  // Construct on a random device to test multi-device, if we have
385
463
  // multiple devices
386
464
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -411,6 +489,7 @@ TEST(TestGpuIndexFlat, LargeIndex) {
411
489
 
412
490
  faiss::gpu::GpuIndexFlatConfig config;
413
491
  config.device = device;
492
+ config.use_raft = use_raft;
414
493
  faiss::gpu::GpuIndexFlatL2 gpuIndexL2(&res, dim, config);
415
494
 
416
495
  cpuIndexL2.add(nb, xb.data());
@@ -430,7 +509,17 @@ TEST(TestGpuIndexFlat, LargeIndex) {
430
509
  0.015f);
431
510
  }
432
511
 
433
- TEST(TestGpuIndexFlat, Residual) {
512
+ TEST(TestGpuIndexFlat, LargeIndex) {
513
+ testLargeIndex(false);
514
+ }
515
+
516
+ #if defined USE_NVIDIA_RAFT
517
+ TEST(TestRaftGpuIndexFlat, LargeIndex) {
518
+ testLargeIndex(true);
519
+ }
520
+ #endif
521
+
522
+ void testResidual(bool use_raft) {
434
523
  // Construct on a random device to test multi-device, if we have
435
524
  // multiple devices
436
525
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -440,6 +529,7 @@ TEST(TestGpuIndexFlat, Residual) {
440
529
 
441
530
  faiss::gpu::GpuIndexFlatConfig config;
442
531
  config.device = device;
532
+ config.use_raft = use_raft;
443
533
 
444
534
  int dim = 32;
445
535
  faiss::IndexFlat cpuIndex(dim, faiss::MetricType::METRIC_L2);
@@ -472,7 +562,17 @@ TEST(TestGpuIndexFlat, Residual) {
472
562
  EXPECT_EQ(residualsCpu, residualsGpu);
473
563
  }
474
564
 
475
- TEST(TestGpuIndexFlat, Reconstruct) {
565
+ TEST(TestGpuIndexFlat, Residual) {
566
+ testResidual(false);
567
+ }
568
+
569
+ #if defined USE_NVIDIA_RAFT
570
+ TEST(TestRaftGpuIndexFlat, Residual) {
571
+ testResidual(true);
572
+ }
573
+ #endif
574
+
575
+ void testReconstruct(bool use_raft) {
476
576
  // Construct on a random device to test multi-device, if we have
477
577
  // multiple devices
478
578
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -489,6 +589,7 @@ TEST(TestGpuIndexFlat, Reconstruct) {
489
589
  faiss::gpu::GpuIndexFlatConfig config;
490
590
  config.device = device;
491
591
  config.useFloat16 = useFloat16;
592
+ config.use_raft = use_raft;
492
593
 
493
594
  faiss::gpu::GpuIndexFlat gpuIndex(
494
595
  &res, dim, faiss::MetricType::METRIC_L2, config);
@@ -553,7 +654,16 @@ TEST(TestGpuIndexFlat, Reconstruct) {
553
654
  }
554
655
  }
555
656
 
556
- TEST(TestGpuIndexFlat, SearchAndReconstruct) {
657
+ TEST(TestGpuIndexFlat, Reconstruct) {
658
+ testReconstruct(false);
659
+ }
660
+ #if defined USE_NVIDIA_RAFT
661
+ TEST(TestRaftGpuIndexFlat, Reconstruct) {
662
+ testReconstruct(true);
663
+ }
664
+ #endif
665
+
666
+ void testSearchAndReconstruct(bool use_raft) {
557
667
  // Construct on a random device to test multi-device, if we have
558
668
  // multiple devices
559
669
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -573,6 +683,7 @@ TEST(TestGpuIndexFlat, SearchAndReconstruct) {
573
683
 
574
684
  faiss::gpu::GpuIndexFlatConfig config;
575
685
  config.device = device;
686
+ config.use_raft = use_raft;
576
687
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config);
577
688
 
578
689
  cpuIndex.add(nb, xb.data());
@@ -639,6 +750,15 @@ TEST(TestGpuIndexFlat, SearchAndReconstruct) {
639
750
  }
640
751
  }
641
752
  }
753
+ TEST(TestGpuIndexFlat, SearchAndReconstruct) {
754
+ testSearchAndReconstruct(false);
755
+ }
756
+
757
+ #if defined USE_NVIDIA_RAFT
758
+ TEST(TestRaftGpuIndexFlat, SearchAndReconstruct) {
759
+ testSearchAndReconstruct(true);
760
+ }
761
+ #endif
642
762
 
643
763
  int main(int argc, char** argv) {
644
764
  testing::InitGoogleTest(&argc, argv);