faiss 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (184) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,111 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #pragma once
10
+
11
+ #include <faiss/IndexPQ.h>
12
+ #include <faiss/impl/ProductQuantizer.h>
13
+ #include <faiss/utils/AlignedTable.h>
14
+
15
+
16
+ namespace faiss {
17
+
18
+
19
+ /** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
20
+ *
21
+ * The codes are not stored sequentially but grouped in blocks of size bbs.
22
+ * This makes it possible to compute distances quickly with SIMD instructions.
23
+ *
24
+ * Implementations:
25
+ * 12: blocked loop with internal loop on Q with qbs
26
+ * 13: same with reservoir accumulator to store results
27
+ * 14: no qbs with heap accumulator
28
+ * 15: no qbs with reservoir accumulator
29
+ */
30
+
31
+ struct IndexPQFastScan: Index {
32
+ ProductQuantizer pq;
33
+
34
+ // implementation to select
35
+ int implem = 0;
36
+ // skip some parts of the computation (for timing)
37
+ int skip = 0;
38
+
39
+ // size of the kernel
40
+ int bbs; // set at build time
41
+ int qbs = 0; // query block size 0 = use default
42
+
43
+ // packed version of the codes
44
+ size_t ntotal2;
45
+ size_t M2;
46
+
47
+ AlignedTable<uint8_t> codes;
48
+
49
+ // this is for testing purposes only (set when initialized by IndexPQ)
50
+ const uint8_t *orig_codes = nullptr;
51
+
52
+ IndexPQFastScan(
53
+ int d, size_t M, size_t nbits,
54
+ MetricType metric = METRIC_L2,
55
+ int bbs = 32
56
+ );
57
+
58
+ IndexPQFastScan();
59
+
60
+ /// build from an existing IndexPQ
61
+ explicit IndexPQFastScan(const IndexPQ & orig, int bbs = 32);
62
+
63
+ void train (idx_t n, const float *x) override;
64
+ void add (idx_t n, const float *x) override;
65
+ void reset() override ;
66
+ void search(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ idx_t* labels) const override;
72
+
73
+ // called by search function
74
+ void compute_quantized_LUT(
75
+ idx_t n, const float* x,
76
+ uint8_t *lut, float *normalizers) const ;
77
+
78
+ template<bool is_max>
79
+ void search_dispatch_implem(
80
+ idx_t n, const float* x, idx_t k,
81
+ float* distances, idx_t* labels) const;
82
+
83
+ template<class C>
84
+ void search_implem_2(
85
+ idx_t n, const float* x, idx_t k,
86
+ float* distances, idx_t* labels) const;
87
+
88
+
89
+ template<class C>
90
+ void search_implem_12(
91
+ idx_t n, const float* x, idx_t k,
92
+ float* distances, idx_t* labels, int impl) const;
93
+
94
+ template<class C>
95
+ void search_implem_14(
96
+ idx_t n, const float* x, idx_t k,
97
+ float* distances, idx_t* labels, int impl) const;
98
+
99
+ };
100
+
101
+ struct FastScanStats {
102
+ uint64_t t0, t1, t2, t3;
103
+ FastScanStats() {reset();}
104
+ void reset() {
105
+ memset(this, 0, sizeof(*this));
106
+ }
107
+ };
108
+
109
+ FAISS_API extern FastScanStats FastScan_stats;
110
+
111
+ } // namespace faiss
@@ -15,6 +15,7 @@
15
15
  #include <memory>
16
16
 
17
17
  #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/AuxIndexStructures.h>
18
19
 
19
20
  namespace faiss {
20
21
 
@@ -282,6 +283,52 @@ void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
282
283
  }
283
284
  }
284
285
 
286
+ namespace {
287
+
288
+ struct PreTransformDistanceComputer: DistanceComputer {
289
+ const IndexPreTransform *index;
290
+ std::unique_ptr<DistanceComputer> sub_dc;
291
+ std::unique_ptr<const float []> query;
292
+
293
+ explicit PreTransformDistanceComputer(const IndexPreTransform *index):
294
+ index(index),
295
+ sub_dc(index->index->get_distance_computer())
296
+ {}
297
+
298
+ void set_query(const float *x) override {
299
+ const float *xt = index->apply_chain (1, x);
300
+ if (xt == x) {
301
+ sub_dc->set_query (x);
302
+ } else {
303
+ query.reset(xt);
304
+ sub_dc->set_query (xt);
305
+ }
306
+ }
307
+
308
+ float symmetric_dis(idx_t i, idx_t j) override
309
+ {
310
+ return sub_dc->symmetric_dis(i, j);
311
+ }
312
+
313
+ float operator () (idx_t i) override
314
+ {
315
+ return (*sub_dc)(i);
316
+ }
317
+
318
+ };
319
+
320
+
321
+ } // anonymous namespace
322
+
323
+
324
+ DistanceComputer * IndexPreTransform::get_distance_computer() const {
325
+ if (chain.empty()) {
326
+ return index->get_distance_computer();
327
+ } else {
328
+ return new PreTransformDistanceComputer(this);
329
+ }
330
+ }
331
+
285
332
 
286
333
 
287
334
  } // namespace faiss
@@ -77,6 +77,8 @@ struct IndexPreTransform: Index {
77
77
  void reverse_chain (idx_t n, const float* xt, float* x) const;
78
78
 
79
79
 
80
+ DistanceComputer * get_distance_computer() const override;
81
+
80
82
  /* standalone codec interface */
81
83
  size_t sa_code_size () const override;
82
84
  void sa_encode (idx_t n, const float *x,
@@ -0,0 +1,256 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #include <faiss/IndexRefine.h>
10
+
11
+ #include <faiss/utils/distances.h>
12
+ #include <faiss/utils/utils.h>
13
+ #include <faiss/utils/Heap.h>
14
+ #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/impl/AuxIndexStructures.h>
16
+ #include <faiss/IndexFlat.h>
17
+
18
+ namespace faiss {
19
+
20
+
21
+
22
+ /***************************************************
23
+ * IndexRefine
24
+ ***************************************************/
25
+
26
+ IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
27
+ Index (base_index->d, base_index->metric_type),
28
+ base_index (base_index),
29
+ refine_index (refine_index)
30
+ {
31
+ own_fields = own_refine_index = false;
32
+ if (refine_index != nullptr) {
33
+ FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
34
+ FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
35
+ is_trained = base_index->is_trained && refine_index->is_trained;
36
+ FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
37
+ } // other case is useful only to construct an IndexRefineFlat
38
+ ntotal = base_index->ntotal;
39
+ }
40
+
41
+ IndexRefine::IndexRefine ():
42
+ base_index(nullptr), refine_index(nullptr),
43
+ own_fields(false), own_refine_index(false)
44
+ {
45
+ }
46
+
47
+ void IndexRefine::train (idx_t n, const float *x)
48
+ {
49
+ base_index->train (n, x);
50
+ refine_index->train (n, x);
51
+ is_trained = true;
52
+ }
53
+
54
+ void IndexRefine::add (idx_t n, const float *x) {
55
+ FAISS_THROW_IF_NOT (is_trained);
56
+ base_index->add (n, x);
57
+ refine_index->add (n, x);
58
+ ntotal = refine_index->ntotal;
59
+ }
60
+
61
+ void IndexRefine::reset ()
62
+ {
63
+ base_index->reset ();
64
+ refine_index->reset ();
65
+ ntotal = 0;
66
+ }
67
+
68
+ namespace {
69
+
70
+ typedef faiss::Index::idx_t idx_t;
71
+
72
+ template<class C>
73
+ static void reorder_2_heaps (
74
+ idx_t n,
75
+ idx_t k, idx_t *labels, float *distances,
76
+ idx_t k_base, const idx_t *base_labels, const float *base_distances)
77
+ {
78
+ #pragma omp parallel for
79
+ for (idx_t i = 0; i < n; i++) {
80
+ idx_t *idxo = labels + i * k;
81
+ float *diso = distances + i * k;
82
+ const idx_t *idxi = base_labels + i * k_base;
83
+ const float *disi = base_distances + i * k_base;
84
+
85
+ heap_heapify<C> (k, diso, idxo, disi, idxi, k);
86
+ if (k_base != k) { // add remaining elements
87
+ heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
88
+ }
89
+ heap_reorder<C> (k, diso, idxo);
90
+ }
91
+ }
92
+
93
+
94
+ } // anonymous namespace
95
+
96
+
97
+
98
+ void IndexRefine::search (
99
+ idx_t n, const float *x, idx_t k,
100
+ float *distances, idx_t *labels) const
101
+ {
102
+ FAISS_THROW_IF_NOT (is_trained);
103
+ idx_t k_base = idx_t (k * k_factor);
104
+ idx_t * base_labels = labels;
105
+ float * base_distances = distances;
106
+ ScopeDeleter<idx_t> del1;
107
+ ScopeDeleter<float> del2;
108
+
109
+ if (k != k_base) {
110
+ base_labels = new idx_t [n * k_base];
111
+ del1.set (base_labels);
112
+ base_distances = new float [n * k_base];
113
+ del2.set (base_distances);
114
+ }
115
+
116
+ base_index->search (n, x, k_base, base_distances, base_labels);
117
+
118
+ for (int i = 0; i < n * k_base; i++)
119
+ assert (base_labels[i] >= -1 &&
120
+ base_labels[i] < ntotal);
121
+
122
+ // parallelize over queries
123
+ #pragma omp parallel if (n > 1)
124
+ {
125
+ std::unique_ptr<DistanceComputer> dc(
126
+ refine_index->get_distance_computer()
127
+ );
128
+ #pragma omp for
129
+ for (idx_t i = 0; i < n; i++) {
130
+ dc->set_query(x + i * d);
131
+ idx_t ij = i * k_base;
132
+ for (idx_t j = 0; j < k_base; j++) {
133
+ idx_t idx = base_labels[ij];
134
+ if (idx < 0) break;
135
+ base_distances[ij] = (*dc)(idx);
136
+ ij++;
137
+ }
138
+ }
139
+ }
140
+
141
+ // sort and store result
142
+ if (metric_type == METRIC_L2) {
143
+ typedef CMax <float, idx_t> C;
144
+ reorder_2_heaps<C> (
145
+ n, k, labels, distances,
146
+ k_base, base_labels, base_distances);
147
+
148
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
149
+ typedef CMin <float, idx_t> C;
150
+ reorder_2_heaps<C> (
151
+ n, k, labels, distances,
152
+ k_base, base_labels, base_distances);
153
+ } else {
154
+ FAISS_THROW_MSG("Metric type not supported");
155
+ }
156
+
157
+ }
158
+
159
+ void IndexRefine::reconstruct (idx_t key, float * recons) const {
160
+ refine_index->reconstruct (key, recons);
161
+ }
162
+
163
+
164
+
165
+
166
+ IndexRefine::~IndexRefine ()
167
+ {
168
+ if (own_fields) delete base_index;
169
+ if (own_refine_index) delete refine_index;
170
+ }
171
+
172
+
173
+ /***************************************************
174
+ * IndexRefineFlat
175
+ ***************************************************/
176
+
177
+ IndexRefineFlat::IndexRefineFlat (Index *base_index):
178
+ IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
179
+ {
180
+ is_trained = base_index->is_trained;
181
+ own_refine_index = true;
182
+ FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
183
+ "base_index should be empty in the beginning");
184
+ }
185
+
186
+
187
+ IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
188
+ IndexRefine (base_index, nullptr)
189
+ {
190
+ is_trained = base_index->is_trained;
191
+ refine_index = new IndexFlat(base_index->d, base_index->metric_type);
192
+ own_refine_index = true;
193
+ refine_index->add (base_index->ntotal, xb);
194
+
195
+ }
196
+
197
+ IndexRefineFlat::IndexRefineFlat():
198
+ IndexRefine()
199
+ {
200
+ own_refine_index = true;
201
+ }
202
+
203
+
204
+ void IndexRefineFlat::search (
205
+ idx_t n, const float *x, idx_t k,
206
+ float *distances, idx_t *labels) const
207
+ {
208
+ FAISS_THROW_IF_NOT (is_trained);
209
+ idx_t k_base = idx_t (k * k_factor);
210
+ idx_t * base_labels = labels;
211
+ float * base_distances = distances;
212
+ ScopeDeleter<idx_t> del1;
213
+ ScopeDeleter<float> del2;
214
+
215
+ if (k != k_base) {
216
+ base_labels = new idx_t [n * k_base];
217
+ del1.set (base_labels);
218
+ base_distances = new float [n * k_base];
219
+ del2.set (base_distances);
220
+ }
221
+
222
+ base_index->search (n, x, k_base, base_distances, base_labels);
223
+
224
+ for (int i = 0; i < n * k_base; i++)
225
+ assert (base_labels[i] >= -1 &&
226
+ base_labels[i] < ntotal);
227
+
228
+ // compute refined distances
229
+ auto rf = dynamic_cast<const IndexFlat *>(refine_index);
230
+ FAISS_THROW_IF_NOT(rf);
231
+
232
+ rf->compute_distance_subset (
233
+ n, x, k_base, base_distances, base_labels);
234
+
235
+ // sort and store result
236
+ if (metric_type == METRIC_L2) {
237
+ typedef CMax <float, idx_t> C;
238
+ reorder_2_heaps<C> (
239
+ n, k, labels, distances,
240
+ k_base, base_labels, base_distances);
241
+
242
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
243
+ typedef CMin <float, idx_t> C;
244
+ reorder_2_heaps<C> (
245
+ n, k, labels, distances,
246
+ k_base, base_labels, base_distances);
247
+ } else {
248
+ FAISS_THROW_MSG("Metric type not supported");
249
+ }
250
+
251
+ }
252
+
253
+
254
+
255
+
256
+ } // namespace faiss
@@ -0,0 +1,73 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+
12
+
13
+ namespace faiss {
14
+
15
+
16
+ /** Index that queries in a base_index (a fast one) and refines the
17
+ * results with an exact search, hopefully improving the results.
18
+ */
19
+ struct IndexRefine: Index {
20
+
21
+ /// faster index used to pre-select the candidate vectors that are then re-ranked by refine_index
22
+ Index *base_index;
23
+
24
+ /// refinement index
25
+ Index *refine_index;
26
+
27
+ bool own_fields; ///< should the base index be deallocated?
28
+ bool own_refine_index; ///< same with the refinement index
29
+
30
+ /// factor between k requested in search and the k requested from
31
+ /// the base_index (should be >= 1)
32
+ float k_factor = 1;
33
+
34
+ /// initialize from empty index
35
+ IndexRefine (Index *base_index, Index *refine_index);
36
+
37
+ IndexRefine ();
38
+
39
+ void train(idx_t n, const float* x) override;
40
+
41
+ void add(idx_t n, const float* x) override;
42
+
43
+ void reset() override;
44
+
45
+ void search(
46
+ idx_t n, const float* x, idx_t k,
47
+ float* distances, idx_t* labels) const override;
48
+
49
+ // reconstruct is routed to the refine_index
50
+ void reconstruct (idx_t key, float * recons) const override;
51
+
52
+ ~IndexRefine() override;
53
+ };
54
+
55
+
56
+ /** Version where the refinement index is an IndexFlat. It has one additional
57
+ * constructor that takes a table of elements to add to the flat refinement
58
+ * index */
59
+ struct IndexRefineFlat: IndexRefine {
60
+ explicit IndexRefineFlat (Index *base_index);
61
+ IndexRefineFlat(Index *base_index, const float *xb);
62
+
63
+ IndexRefineFlat();
64
+
65
+ void search(
66
+ idx_t n, const float* x, idx_t k,
67
+ float* distances, idx_t* labels) const override;
68
+
69
+ };
70
+
71
+
72
+
73
+ } // namespace faiss