faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -11,6 +11,7 @@
11
11
  #include <faiss/gpu/impl/IndexUtils.h>
12
12
  #include <faiss/gpu/test/TestUtils.h>
13
13
  #include <faiss/gpu/utils/DeviceUtils.h>
14
+ #include <faiss/utils/random.h>
14
15
  #include <faiss/utils/utils.h>
15
16
  #include <gtest/gtest.h>
16
17
  #include <sstream>
@@ -164,6 +165,30 @@ TEST(TestGpuIndexBinaryFlat, LargeIndex) {
164
165
  compareBinaryDist(cpuDist, cpuLabels, gpuDist, gpuLabels, nq, k);
165
166
  }
166
167
 
168
+ TEST(TestGpuIndexBinaryFlat, Reconstruct) {
169
+ int n = 1000;
170
+ std::vector<uint8_t> xb(8 * n);
171
+ faiss::byte_rand(xb.data(), xb.size(), 123);
172
+ std::unique_ptr<faiss::IndexBinaryFlat> index(
173
+ new faiss::IndexBinaryFlat(64));
174
+ index->add(n, xb.data());
175
+
176
+ std::vector<uint8_t> xb3(8 * n);
177
+ index->reconstruct_n(0, index->ntotal, xb3.data());
178
+ EXPECT_EQ(xb, xb3);
179
+
180
+ faiss::gpu::StandardGpuResources res;
181
+ res.noTempMemory();
182
+
183
+ std::unique_ptr<faiss::gpu::GpuIndexBinaryFlat> index2(
184
+ new faiss::gpu::GpuIndexBinaryFlat(&res, index.get()));
185
+
186
+ std::vector<uint8_t> xb2(8 * n);
187
+
188
+ index2->reconstruct_n(0, index->ntotal, xb2.data());
189
+ EXPECT_EQ(xb2, xb3);
190
+ }
191
+
167
192
  int main(int argc, char** argv) {
168
193
  testing::InitGoogleTest(&argc, argv);
169
194
 
@@ -28,7 +28,8 @@ struct TestFlatOptions {
28
28
  numVecsOverride(-1),
29
29
  numQueriesOverride(-1),
30
30
  kOverride(-1),
31
- dimOverride(-1) {}
31
+ dimOverride(-1),
32
+ use_raft(false) {}
32
33
 
33
34
  faiss::MetricType metric;
34
35
  float metricArg;
@@ -38,6 +39,7 @@ struct TestFlatOptions {
38
39
  int numQueriesOverride;
39
40
  int kOverride;
40
41
  int dimOverride;
42
+ bool use_raft;
41
43
  };
42
44
 
43
45
  void testFlat(const TestFlatOptions& opt) {
@@ -73,6 +75,7 @@ void testFlat(const TestFlatOptions& opt) {
73
75
  faiss::gpu::GpuIndexFlatConfig config;
74
76
  config.device = device;
75
77
  config.useFloat16 = opt.useFloat16;
78
+ config.use_raft = opt.use_raft;
76
79
 
77
80
  faiss::gpu::GpuIndexFlat gpuIndex(&res, dim, opt.metric, config);
78
81
  gpuIndex.metric_arg = opt.metricArg;
@@ -110,6 +113,11 @@ TEST(TestGpuIndexFlat, IP_Float32) {
110
113
  opt.useFloat16 = false;
111
114
 
112
115
  testFlat(opt);
116
+
117
+ #if defined USE_NVIDIA_RAFT
118
+ opt.use_raft = true;
119
+ testFlat(opt);
120
+ #endif
113
121
  }
114
122
  }
115
123
 
@@ -119,6 +127,11 @@ TEST(TestGpuIndexFlat, L1_Float32) {
119
127
  opt.useFloat16 = false;
120
128
 
121
129
  testFlat(opt);
130
+
131
+ #if defined USE_NVIDIA_RAFT
132
+ opt.use_raft = true;
133
+ testFlat(opt);
134
+ #endif
122
135
  }
123
136
 
124
137
  TEST(TestGpuIndexFlat, Lp_Float32) {
@@ -128,6 +141,10 @@ TEST(TestGpuIndexFlat, Lp_Float32) {
128
141
  opt.useFloat16 = false;
129
142
 
130
143
  testFlat(opt);
144
+ #if defined USE_NVIDIA_RAFT
145
+ opt.use_raft = true;
146
+ testFlat(opt);
147
+ #endif
131
148
  }
132
149
 
133
150
  TEST(TestGpuIndexFlat, L2_Float32) {
@@ -138,6 +155,10 @@ TEST(TestGpuIndexFlat, L2_Float32) {
138
155
  opt.useFloat16 = false;
139
156
 
140
157
  testFlat(opt);
158
+ #if defined USE_NVIDIA_RAFT
159
+ opt.use_raft = true;
160
+ testFlat(opt);
161
+ #endif
141
162
  }
142
163
  }
143
164
 
@@ -152,6 +173,10 @@ TEST(TestGpuIndexFlat, L2_k_2048) {
152
173
  opt.numVecsOverride = 10000;
153
174
 
154
175
  testFlat(opt);
176
+ #if defined USE_NVIDIA_RAFT
177
+ opt.use_raft = true;
178
+ testFlat(opt);
179
+ #endif
155
180
  }
156
181
  }
157
182
 
@@ -164,6 +189,10 @@ TEST(TestGpuIndexFlat, L2_Float32_K1) {
164
189
  opt.kOverride = 1;
165
190
 
166
191
  testFlat(opt);
192
+ #if defined USE_NVIDIA_RAFT
193
+ opt.use_raft = true;
194
+ testFlat(opt);
195
+ #endif
167
196
  }
168
197
  }
169
198
 
@@ -174,6 +203,10 @@ TEST(TestGpuIndexFlat, IP_Float16) {
174
203
  opt.useFloat16 = true;
175
204
 
176
205
  testFlat(opt);
206
+ #if defined USE_NVIDIA_RAFT
207
+ opt.use_raft = true;
208
+ testFlat(opt);
209
+ #endif
177
210
  }
178
211
  }
179
212
 
@@ -184,6 +217,10 @@ TEST(TestGpuIndexFlat, L2_Float16) {
184
217
  opt.useFloat16 = true;
185
218
 
186
219
  testFlat(opt);
220
+ #if defined USE_NVIDIA_RAFT
221
+ opt.use_raft = true;
222
+ testFlat(opt);
223
+ #endif
187
224
  }
188
225
  }
189
226
 
@@ -196,6 +233,10 @@ TEST(TestGpuIndexFlat, L2_Float16_K1) {
196
233
  opt.kOverride = 1;
197
234
 
198
235
  testFlat(opt);
236
+ #if defined USE_NVIDIA_RAFT
237
+ opt.use_raft = true;
238
+ testFlat(opt);
239
+ #endif
199
240
  }
200
241
  }
201
242
 
@@ -213,6 +254,10 @@ TEST(TestGpuIndexFlat, L2_Tiling) {
213
254
  opt.kOverride = 64;
214
255
 
215
256
  testFlat(opt);
257
+ #if defined USE_NVIDIA_RAFT
258
+ opt.use_raft = true;
259
+ testFlat(opt);
260
+ #endif
216
261
  }
217
262
  }
218
263
 
@@ -223,7 +268,7 @@ TEST(TestGpuIndexFlat, QueryEmpty) {
223
268
  faiss::gpu::GpuIndexFlatConfig config;
224
269
  config.device = 0;
225
270
  config.useFloat16 = false;
226
-
271
+ config.use_raft = false;
227
272
  int dim = 128;
228
273
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config);
229
274
 
@@ -247,7 +292,7 @@ TEST(TestGpuIndexFlat, QueryEmpty) {
247
292
  }
248
293
  }
249
294
 
250
- TEST(TestGpuIndexFlat, CopyFrom) {
295
+ void testCopyFrom(bool use_raft) {
251
296
  int numVecs = faiss::gpu::randVal(100, 200);
252
297
  int dim = faiss::gpu::randVal(1, 1000);
253
298
 
@@ -265,6 +310,7 @@ TEST(TestGpuIndexFlat, CopyFrom) {
265
310
  faiss::gpu::GpuIndexFlatConfig config;
266
311
  config.device = device;
267
312
  config.useFloat16 = useFloat16;
313
+ config.use_raft = use_raft;
268
314
 
269
315
  // Fill with garbage values
270
316
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, 2000, config);
@@ -293,7 +339,17 @@ TEST(TestGpuIndexFlat, CopyFrom) {
293
339
  }
294
340
  }
295
341
 
296
- TEST(TestGpuIndexFlat, CopyTo) {
342
+ TEST(TestGpuIndexFlat, CopyFrom) {
343
+ testCopyFrom(false);
344
+ }
345
+
346
+ #if defined USE_NVIDIA_RAFT
347
+ TEST(TestRaftGpuIndexFlat, CopyFrom) {
348
+ testCopyFrom(true);
349
+ }
350
+ #endif
351
+
352
+ void testCopyTo(bool use_raft) {
297
353
  faiss::gpu::StandardGpuResources res;
298
354
  res.noTempMemory();
299
355
 
@@ -307,6 +363,7 @@ TEST(TestGpuIndexFlat, CopyTo) {
307
363
  faiss::gpu::GpuIndexFlatConfig config;
308
364
  config.device = device;
309
365
  config.useFloat16 = useFloat16;
366
+ config.use_raft = use_raft;
310
367
 
311
368
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config);
312
369
  gpuIndex.add(numVecs, vecs.data());
@@ -333,7 +390,17 @@ TEST(TestGpuIndexFlat, CopyTo) {
333
390
  }
334
391
  }
335
392
 
336
- TEST(TestGpuIndexFlat, UnifiedMemory) {
393
+ TEST(TestGpuIndexFlat, CopyTo) {
394
+ testCopyTo(false);
395
+ }
396
+
397
+ #if defined USE_NVIDIA_RAFT
398
+ TEST(TestRaftGpuIndexFlat, CopyTo) {
399
+ testCopyTo(true);
400
+ }
401
+ #endif
402
+
403
+ void testUnifiedMemory(bool use_raft) {
337
404
  // Construct on a random device to test multi-device, if we have
338
405
  // multiple devices
339
406
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -359,6 +426,7 @@ TEST(TestGpuIndexFlat, UnifiedMemory) {
359
426
  faiss::gpu::GpuIndexFlatConfig config;
360
427
  config.device = device;
361
428
  config.memorySpace = faiss::gpu::MemorySpace::Unified;
429
+ config.use_raft = use_raft;
362
430
 
363
431
  faiss::gpu::GpuIndexFlatL2 gpuIndexL2(&res, dim, config);
364
432
 
@@ -380,7 +448,17 @@ TEST(TestGpuIndexFlat, UnifiedMemory) {
380
448
  0.015f);
381
449
  }
382
450
 
383
- TEST(TestGpuIndexFlat, LargeIndex) {
451
+ TEST(TestGpuIndexFlat, UnifiedMemory) {
452
+ testUnifiedMemory(false);
453
+ }
454
+
455
+ #if defined USE_NVIDIA_RAFT
456
+ TEST(TestRaftGpuIndexFlat, UnifiedMemory) {
457
+ testUnifiedMemory(true);
458
+ }
459
+ #endif
460
+
461
+ void testLargeIndex(bool use_raft) {
384
462
  // Construct on a random device to test multi-device, if we have
385
463
  // multiple devices
386
464
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -411,6 +489,7 @@ TEST(TestGpuIndexFlat, LargeIndex) {
411
489
 
412
490
  faiss::gpu::GpuIndexFlatConfig config;
413
491
  config.device = device;
492
+ config.use_raft = use_raft;
414
493
  faiss::gpu::GpuIndexFlatL2 gpuIndexL2(&res, dim, config);
415
494
 
416
495
  cpuIndexL2.add(nb, xb.data());
@@ -430,7 +509,17 @@ TEST(TestGpuIndexFlat, LargeIndex) {
430
509
  0.015f);
431
510
  }
432
511
 
433
- TEST(TestGpuIndexFlat, Residual) {
512
+ TEST(TestGpuIndexFlat, LargeIndex) {
513
+ testLargeIndex(false);
514
+ }
515
+
516
+ #if defined USE_NVIDIA_RAFT
517
+ TEST(TestRaftGpuIndexFlat, LargeIndex) {
518
+ testLargeIndex(true);
519
+ }
520
+ #endif
521
+
522
+ void testResidual(bool use_raft) {
434
523
  // Construct on a random device to test multi-device, if we have
435
524
  // multiple devices
436
525
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -440,6 +529,7 @@ TEST(TestGpuIndexFlat, Residual) {
440
529
 
441
530
  faiss::gpu::GpuIndexFlatConfig config;
442
531
  config.device = device;
532
+ config.use_raft = use_raft;
443
533
 
444
534
  int dim = 32;
445
535
  faiss::IndexFlat cpuIndex(dim, faiss::MetricType::METRIC_L2);
@@ -472,7 +562,17 @@ TEST(TestGpuIndexFlat, Residual) {
472
562
  EXPECT_EQ(residualsCpu, residualsGpu);
473
563
  }
474
564
 
475
- TEST(TestGpuIndexFlat, Reconstruct) {
565
+ TEST(TestGpuIndexFlat, Residual) {
566
+ testResidual(false);
567
+ }
568
+
569
+ #if defined USE_NVIDIA_RAFT
570
+ TEST(TestRaftGpuIndexFlat, Residual) {
571
+ testResidual(true);
572
+ }
573
+ #endif
574
+
575
+ void testReconstruct(bool use_raft) {
476
576
  // Construct on a random device to test multi-device, if we have
477
577
  // multiple devices
478
578
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -489,6 +589,7 @@ TEST(TestGpuIndexFlat, Reconstruct) {
489
589
  faiss::gpu::GpuIndexFlatConfig config;
490
590
  config.device = device;
491
591
  config.useFloat16 = useFloat16;
592
+ config.use_raft = use_raft;
492
593
 
493
594
  faiss::gpu::GpuIndexFlat gpuIndex(
494
595
  &res, dim, faiss::MetricType::METRIC_L2, config);
@@ -553,7 +654,16 @@ TEST(TestGpuIndexFlat, Reconstruct) {
553
654
  }
554
655
  }
555
656
 
556
- TEST(TestGpuIndexFlat, SearchAndReconstruct) {
657
+ TEST(TestGpuIndexFlat, Reconstruct) {
658
+ testReconstruct(false);
659
+ }
660
+ #if defined USE_NVIDIA_RAFT
661
+ TEST(TestRaftGpuIndexFlat, Reconstruct) {
662
+ testReconstruct(true);
663
+ }
664
+ #endif
665
+
666
+ void testSearchAndReconstruct(bool use_raft) {
557
667
  // Construct on a random device to test multi-device, if we have
558
668
  // multiple devices
559
669
  int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
@@ -573,6 +683,7 @@ TEST(TestGpuIndexFlat, SearchAndReconstruct) {
573
683
 
574
684
  faiss::gpu::GpuIndexFlatConfig config;
575
685
  config.device = device;
686
+ config.use_raft = use_raft;
576
687
  faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config);
577
688
 
578
689
  cpuIndex.add(nb, xb.data());
@@ -639,6 +750,15 @@ TEST(TestGpuIndexFlat, SearchAndReconstruct) {
639
750
  }
640
751
  }
641
752
  }
753
+ TEST(TestGpuIndexFlat, SearchAndReconstruct) {
754
+ testSearchAndReconstruct(false);
755
+ }
756
+
757
+ #if defined USE_NVIDIA_RAFT
758
+ TEST(TestRaftGpuIndexFlat, SearchAndReconstruct) {
759
+ testSearchAndReconstruct(true);
760
+ }
761
+ #endif
642
762
 
643
763
  int main(int argc, char** argv) {
644
764
  testing::InitGoogleTest(&argc, argv);