faiss 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,111 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #pragma once
10
+
11
+ #include <faiss/IndexPQ.h>
12
+ #include <faiss/impl/ProductQuantizer.h>
13
+ #include <faiss/utils/AlignedTable.h>
14
+
15
+
16
+ namespace faiss {
17
+
18
+
19
+ /** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
20
+ *
21
+ * The codes are not stored sequentially but grouped in blocks of size bbs.
22
+ * This makes it possible to compute distances quickly with SIMD instructions.
23
+ *
24
+ * Implementations:
25
+ * 12: blocked loop with internal loop on Q with qbs
26
+ * 13: same with reservoir accumulator to store results
27
+ * 14: no qbs with heap accumulator
28
+ * 15: no qbs with reservoir accumulator
29
+ */
30
+
31
+ struct IndexPQFastScan: Index {
32
+ ProductQuantizer pq; // product quantizer held by this index
33
+
34
+ // implementation to select (numbered list in the struct comment above); defaults to 0
35
+ int implem = 0;
36
+ // skip some parts of the computation (for timing)
37
+ int skip = 0;
38
+
39
+ // size of the kernel
40
+ int bbs; // set at build time (both bbs-taking constructors default it to 32)
41
+ int qbs = 0; // query block size 0 = use default
42
+
43
+ // packed version of the codes
44
+ size_t ntotal2; // NOTE(review): presumably ntotal padded for the blocked layout -- confirm in the .cpp
45
+ size_t M2; // NOTE(review): presumably M padded for the packed layout -- confirm in the .cpp
46
+
47
+ AlignedTable<uint8_t> codes; // code storage (struct comment: codes are grouped in blocks of size bbs)
48
+
49
+ // this is for testing purposes only (set when initialized by IndexPQ)
50
+ const uint8_t *orig_codes = nullptr;
51
+
52
+ IndexPQFastScan(
53
+ int d, size_t M, size_t nbits,
54
+ MetricType metric = METRIC_L2,
55
+ int bbs = 32
56
+ );
57
+ // default constructor: bbs, ntotal2 and M2 have no in-class initializer in this header
58
+ IndexPQFastScan();
59
+
60
+ /// build from an existing IndexPQ
61
+ explicit IndexPQFastScan(const IndexPQ & orig, int bbs = 32);
62
+
63
+ void train (idx_t n, const float *x) override;
64
+ void add (idx_t n, const float *x) override;
65
+ void reset() override ;
66
+ void search(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ idx_t* labels) const override;
72
+
73
+ // called by search function
74
+ void compute_quantized_LUT(
75
+ idx_t n, const float* x,
76
+ uint8_t *lut, float *normalizers) const ;
77
+ // NOTE(review): is_max presumably selects the comparison direction for the metric -- confirm
78
+ template<bool is_max>
79
+ void search_dispatch_implem(
80
+ idx_t n, const float* x, idx_t k,
81
+ float* distances, idx_t* labels) const;
82
+
83
+ template<class C>
84
+ void search_implem_2(
85
+ idx_t n, const float* x, idx_t k,
86
+ float* distances, idx_t* labels) const;
87
+
88
+ // NOTE(review): impl is presumably 12 or 13 from the struct comment -- confirm
89
+ template<class C>
90
+ void search_implem_12(
91
+ idx_t n, const float* x, idx_t k,
92
+ float* distances, idx_t* labels, int impl) const;
93
+ // NOTE(review): impl is presumably 14 or 15 from the struct comment -- confirm
94
+ template<class C>
95
+ void search_implem_14(
96
+ idx_t n, const float* x, idx_t k,
97
+ float* distances, idx_t* labels, int impl) const;
98
+
99
+ };
100
+ // Four 64-bit counters, zeroed on construction and by reset().
101
+ struct FastScanStats {
102
+ uint64_t t0, t1, t2, t3;
103
+ FastScanStats() {reset();}
104
+ void reset() {
105
+ t0 = t1 = t2 = t3 = 0; // plain assignments: no dependency on <cstring> for memset
106
+ }
107
+ };
108
+
109
+ FAISS_API extern FastScanStats FastScan_stats;
110
+
111
+ } // namespace faiss
@@ -15,6 +15,7 @@
15
15
  #include <memory>
16
16
 
17
17
  #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/AuxIndexStructures.h>
18
19
 
19
20
  namespace faiss {
20
21
 
@@ -282,6 +283,52 @@ void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
282
283
  }
283
284
  }
284
285
 
286
+ namespace {
287
+ // Distance computer that runs each query through the index's transform chain, then delegates to the sub-index's computer.
288
+ struct PreTransformDistanceComputer: DistanceComputer {
289
+ const IndexPreTransform *index; // not owned
290
+ std::unique_ptr<DistanceComputer> sub_dc; // computer obtained from index->index, owned here
291
+ std::unique_ptr<const float []> query; // owns the transformed query when apply_chain returned a new pointer
292
+
293
+ explicit PreTransformDistanceComputer(const IndexPreTransform *index):
294
+ index(index),
295
+ sub_dc(index->index->get_distance_computer())
296
+ {}
297
+ // Transforms x (one vector) and forwards the result to sub_dc.
298
+ void set_query(const float *x) override {
299
+ const float *xt = index->apply_chain (1, x);
300
+ if (xt == x) { // chain returned the input pointer: nothing to own
301
+ sub_dc->set_query (x);
302
+ } else {
303
+ query.reset(xt); // take ownership; any previously held query buffer is released
304
+ sub_dc->set_query (xt);
305
+ }
306
+ }
307
+ // Forwarded unchanged to the sub-index computer.
308
+ float symmetric_dis(idx_t i, idx_t j) override
309
+ {
310
+ return sub_dc->symmetric_dis(i, j);
311
+ }
312
+ // Distance from the current query to stored vector i, as computed by sub_dc.
313
+ float operator () (idx_t i) override
314
+ {
315
+ return (*sub_dc)(i);
316
+ }
317
+
318
+ };
319
+
320
+
321
+ } // anonymous namespace
322
+
323
+ // Returns a heap-allocated DistanceComputer; the wrapper branch uses `new`, so the caller is responsible for deleting it.
324
+ DistanceComputer * IndexPreTransform::get_distance_computer() const {
325
+ if (chain.empty()) { // no transform: hand out the sub-index's computer directly
326
+ return index->get_distance_computer();
327
+ } else {
328
+ return new PreTransformDistanceComputer(this); // applies the chain in set_query
329
+ }
330
+ }
331
+
285
332
 
286
333
 
287
334
  } // namespace faiss
@@ -77,6 +77,8 @@ struct IndexPreTransform: Index {
77
77
  void reverse_chain (idx_t n, const float* xt, float* x) const;
78
78
 
79
79
 
80
+ DistanceComputer * get_distance_computer() const override;
81
+
80
82
  /* standalone codec interface */
81
83
  size_t sa_code_size () const override;
82
84
  void sa_encode (idx_t n, const float *x,
@@ -0,0 +1,256 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #include <faiss/IndexRefine.h>
10
+
11
+ #include <faiss/utils/distances.h>
12
+ #include <faiss/utils/utils.h>
13
+ #include <faiss/utils/Heap.h>
14
+ #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/impl/AuxIndexStructures.h>
16
+ #include <faiss/IndexFlat.h>
17
+
18
+ namespace faiss {
19
+
20
+
21
+
22
+ /***************************************************
23
+ * IndexRefine
24
+ ***************************************************/
25
+ // Wraps base_index and refine_index without taking ownership; refine_index may be nullptr.
26
+ IndexRefine::IndexRefine (Index *base_index, Index *refine_index):
27
+ Index (base_index->d, base_index->metric_type),
28
+ base_index (base_index),
29
+ refine_index (refine_index)
30
+ {
31
+ own_fields = own_refine_index = false;
32
+ if (refine_index != nullptr) { // both indexes must agree on d, metric type and ntotal
33
+ FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
34
+ FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
35
+ is_trained = base_index->is_trained && refine_index->is_trained;
36
+ FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
37
+ } // other case is useful only to construct an IndexRefineFlat
38
+ ntotal = base_index->ntotal;
39
+ }
40
+ // Default constructor: both index pointers are null and nothing is owned.
41
+ IndexRefine::IndexRefine ():
42
+ base_index(nullptr), refine_index(nullptr),
43
+ own_fields(false), own_refine_index(false)
44
+ {
45
+ }
46
+ // Trains both sub-indexes on the same n vectors, then marks this index as trained.
47
+ void IndexRefine::train (idx_t n, const float *x)
48
+ {
49
+ base_index->train (n, x);
50
+ refine_index->train (n, x); // dereferenced unconditionally: refine_index must be non-null here
51
+ is_trained = true;
52
+ }
53
+ // Adds the same n vectors to both sub-indexes; throws if the index is not trained.
54
+ void IndexRefine::add (idx_t n, const float *x) {
55
+ FAISS_THROW_IF_NOT (is_trained);
56
+ base_index->add (n, x);
57
+ refine_index->add (n, x);
58
+ ntotal = refine_index->ntotal; // the refine index is taken as the reference count
59
+ }
60
+ // Resets both sub-indexes and sets ntotal back to 0.
61
+ void IndexRefine::reset ()
62
+ {
63
+ base_index->reset ();
64
+ refine_index->reset ();
65
+ ntotal = 0;
66
+ }
67
+
68
+ namespace {
69
+
70
+ typedef faiss::Index::idx_t idx_t;
71
+ // For each of n queries, reduces the k_base candidates in base_* to k results in labels/distances, using heap comparator C.
72
+ template<class C>
73
+ static void reorder_2_heaps (
74
+ idx_t n,
75
+ idx_t k, idx_t *labels, float *distances,
76
+ idx_t k_base, const idx_t *base_labels, const float *base_distances)
77
+ {
78
+ #pragma omp parallel for
79
+ for (idx_t i = 0; i < n; i++) {
80
+ idx_t *idxo = labels + i * k; // output row i (k entries)
81
+ float *diso = distances + i * k;
82
+ const idx_t *idxi = base_labels + i * k_base; // input row i (k_base entries)
83
+ const float *disi = base_distances + i * k_base;
84
+ // the first k candidates seed the heap; this reads k input entries, so k_base must be >= k
85
+ heap_heapify<C> (k, diso, idxo, disi, idxi, k);
86
+ if (k_base != k) { // add remaining elements
87
+ heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
88
+ }
89
+ heap_reorder<C> (k, diso, idxo); // NOTE(review): presumably leaves the k results sorted -- confirm in Heap.h
90
+ }
91
+ }
92
+
93
+
94
+ } // anonymous namespace
95
+
96
+
97
+ // Searches k * k_factor candidates in base_index, recomputes their distances with refine_index, keeps the best k.
98
+ void IndexRefine::search (
99
+ idx_t n, const float *x, idx_t k,
100
+ float *distances, idx_t *labels) const
101
+ {
102
+ FAISS_THROW_IF_NOT (is_trained);
103
+ idx_t k_base = idx_t (k * k_factor);
104
+ idx_t * base_labels = labels; // when k_base == k the caller's buffers are used directly
105
+ float * base_distances = distances;
106
+ ScopeDeleter<idx_t> del1;
107
+ ScopeDeleter<float> del2;
108
+ FAISS_THROW_IF_NOT_MSG (k_base >= k, "k_factor should be >= 1"); // reorder_2_heaps reads k entries per k_base-sized row
109
+ if (k != k_base) {
110
+ base_labels = new idx_t [n * k_base];
111
+ del1.set (base_labels);
112
+ base_distances = new float [n * k_base];
113
+ del2.set (base_distances);
114
+ }
115
+
116
+ base_index->search (n, x, k_base, base_distances, base_labels);
117
+ // debug-only sanity check; idx_t counter so it cannot overflow before reaching n * k_base
118
+ for (idx_t i = 0; i < n * k_base; i++)
119
+ assert (base_labels[i] >= -1 &&
120
+ base_labels[i] < ntotal);
121
+
122
+ // parallelize over queries
123
+ #pragma omp parallel if (n > 1)
124
+ {
125
+ std::unique_ptr<DistanceComputer> dc(
126
+ refine_index->get_distance_computer()
127
+ ); // declared inside the parallel region, so each thread gets its own computer
128
+ #pragma omp for
129
+ for (idx_t i = 0; i < n; i++) {
130
+ dc->set_query(x + i * d);
131
+ idx_t ij = i * k_base;
132
+ for (idx_t j = 0; j < k_base; j++) {
133
+ idx_t idx = base_labels[ij];
134
+ if (idx < 0) break; // stop at the first negative label; later slots keep the base distances
135
+ base_distances[ij] = (*dc)(idx); // overwrite with the refined distance
136
+ ij++;
137
+ }
138
+ }
139
+ }
140
+
141
+ // sort and store result
142
+ if (metric_type == METRIC_L2) {
143
+ typedef CMax <float, idx_t> C;
144
+ reorder_2_heaps<C> (
145
+ n, k, labels, distances,
146
+ k_base, base_labels, base_distances);
147
+
148
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
149
+ typedef CMin <float, idx_t> C;
150
+ reorder_2_heaps<C> (
151
+ n, k, labels, distances,
152
+ k_base, base_labels, base_distances);
153
+ } else {
154
+ FAISS_THROW_MSG("Metric type not supported");
155
+ }
156
+
157
+ }
158
+ // Reconstruction is delegated entirely to refine_index.
159
+ void IndexRefine::reconstruct (idx_t key, float * recons) const {
160
+ refine_index->reconstruct (key, recons);
161
+ }
162
+
163
+
164
+
165
+ // Deletes each sub-index only if the matching ownership flag is set.
166
+ IndexRefine::~IndexRefine ()
167
+ {
168
+ if (own_fields) delete base_index; // own_fields governs base_index only
169
+ if (own_refine_index) delete refine_index;
170
+ }
171
+
172
+
173
+ /***************************************************
174
+ * IndexRefineFlat
175
+ ***************************************************/
176
+ // Builds on an empty base_index with a newly allocated IndexFlat (same d and metric) as the owned refine index.
177
+ IndexRefineFlat::IndexRefineFlat (Index *base_index):
178
+ IndexRefine(base_index, new IndexFlat(base_index->d, base_index->metric_type))
179
+ {
180
+ is_trained = base_index->is_trained; // depends only on the base index here
181
+ own_refine_index = true; // the IndexFlat allocated above is deleted by ~IndexRefine
182
+ FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
183
+ "base_index should be empty in the beginning");
184
+ }
185
+
186
+ // Builds on an already populated base_index: xb must hold base_index->ntotal vectors, which are added to a new owned IndexFlat.
187
+ IndexRefineFlat::IndexRefineFlat (Index *base_index, const float *xb):
188
+ IndexRefine (base_index, nullptr)
189
+ {
190
+ is_trained = base_index->is_trained;
191
+ refine_index = new IndexFlat(base_index->d, base_index->metric_type);
192
+ own_refine_index = true;
193
+ refine_index->add (base_index->ntotal, xb); // reads base_index->ntotal vectors from xb
194
+
195
+ }
196
+ // Default constructor: no indexes set yet, but the refine index is flagged as owned once assigned.
197
+ IndexRefineFlat::IndexRefineFlat():
198
+ IndexRefine()
199
+ {
200
+ own_refine_index = true;
201
+ }
202
+
203
+ // Same contract as IndexRefine::search, but refined distances come from IndexFlat::compute_distance_subset.
204
+ void IndexRefineFlat::search (
205
+ idx_t n, const float *x, idx_t k,
206
+ float *distances, idx_t *labels) const
207
+ {
208
+ FAISS_THROW_IF_NOT (is_trained);
209
+ idx_t k_base = idx_t (k * k_factor);
210
+ idx_t * base_labels = labels; // when k_base == k the caller's buffers are used directly
211
+ float * base_distances = distances;
212
+ ScopeDeleter<idx_t> del1;
213
+ ScopeDeleter<float> del2;
214
+ FAISS_THROW_IF_NOT_MSG (k_base >= k, "k_factor should be >= 1"); // reorder_2_heaps reads k entries per k_base-sized row
215
+ if (k != k_base) {
216
+ base_labels = new idx_t [n * k_base];
217
+ del1.set (base_labels);
218
+ base_distances = new float [n * k_base];
219
+ del2.set (base_distances);
220
+ }
221
+
222
+ base_index->search (n, x, k_base, base_distances, base_labels);
223
+ // debug-only sanity check; idx_t counter so it cannot overflow before reaching n * k_base
224
+ for (idx_t i = 0; i < n * k_base; i++)
225
+ assert (base_labels[i] >= -1 &&
226
+ base_labels[i] < ntotal);
227
+
228
+ // compute refined distances
229
+ auto rf = dynamic_cast<const IndexFlat *>(refine_index);
230
+ FAISS_THROW_IF_NOT(rf); // the refine index must really be an IndexFlat
231
+
232
+ rf->compute_distance_subset (
233
+ n, x, k_base, base_distances, base_labels);
234
+
235
+ // sort and store result
236
+ if (metric_type == METRIC_L2) {
237
+ typedef CMax <float, idx_t> C;
238
+ reorder_2_heaps<C> (
239
+ n, k, labels, distances,
240
+ k_base, base_labels, base_distances);
241
+
242
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
243
+ typedef CMin <float, idx_t> C;
244
+ reorder_2_heaps<C> (
245
+ n, k, labels, distances,
246
+ k_base, base_labels, base_distances);
247
+ } else {
248
+ FAISS_THROW_MSG("Metric type not supported");
249
+ }
250
+
251
+ }
252
+
253
+
254
+
255
+
256
+ } // namespace faiss
@@ -0,0 +1,73 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+
12
+
13
+ namespace faiss {
14
+
15
+
16
+ /** Index that queries in a base_index (a fast one) and refines the
17
+ * results with an exact search, hopefully improving the results.
18
+ */
19
+ struct IndexRefine: Index {
20
+
21
+ /// faster index to pre-select the vectors that should be filtered
22
+ Index *base_index;
23
+
24
+ /// refinement index
25
+ Index *refine_index;
26
+
27
+ bool own_fields; ///< should the base index be deallocated?
28
+ bool own_refine_index; ///< same with the refinement index
29
+
30
+ /// factor between k requested in search and the k requested from
31
+ /// the base_index (should be >= 1)
32
+ float k_factor = 1;
33
+
34
+ /// initialize from a base index and a refinement index (refine_index may be nullptr)
35
+ IndexRefine (Index *base_index, Index *refine_index);
36
+
37
+ IndexRefine ();
38
+
39
+ void train(idx_t n, const float* x) override;
40
+
41
+ void add(idx_t n, const float* x) override;
42
+
43
+ void reset() override;
44
+
45
+ void search(
46
+ idx_t n, const float* x, idx_t k,
47
+ float* distances, idx_t* labels) const override;
48
+
49
+ // reconstruct is routed to the refine_index
50
+ void reconstruct (idx_t key, float * recons) const override;
51
+
52
+ ~IndexRefine() override;
53
+ };
54
+
55
+
56
+ /** Version where the refinement index is an IndexFlat. It has one additional
57
+ * constructor that takes a table of elements to add to the flat refinement
58
+ * index */
59
+ struct IndexRefineFlat: IndexRefine {
60
+ explicit IndexRefineFlat (Index *base_index); // base_index is expected to be empty
61
+ IndexRefineFlat(Index *base_index, const float *xb); // xb: the vectors to add to the flat refinement index
62
+
63
+ IndexRefineFlat();
64
+
65
+ void search(
66
+ idx_t n, const float* x, idx_t k,
67
+ float* distances, idx_t* labels) const override;
68
+
69
+ };
70
+
71
+
72
+
73
+ } // namespace faiss