faiss 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +36 -33
- data/vendor/faiss/faiss/AutoTune.h +6 -3
- data/vendor/faiss/faiss/Clustering.cpp +16 -12
- data/vendor/faiss/faiss/Index.cpp +3 -4
- data/vendor/faiss/faiss/Index.h +3 -3
- data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
- data/vendor/faiss/faiss/IndexBinary.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
- data/vendor/faiss/faiss/IndexFlat.h +0 -51
- data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
- data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
- data/vendor/faiss/faiss/IndexIVF.h +22 -15
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
- data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
- data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
- data/vendor/faiss/faiss/IndexRefine.h +73 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
- data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
- data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
- data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
- data/vendor/faiss/faiss/impl/io.cpp +33 -2
- data/vendor/faiss/faiss/impl/io.h +7 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
- data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
- data/vendor/faiss/faiss/index_factory.cpp +112 -7
- data/vendor/faiss/faiss/index_io.h +1 -48
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
- data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
- data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
- data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
- data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
- data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
- data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
- data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
- data/vendor/faiss/faiss/utils/Heap.h +61 -50
- data/vendor/faiss/faiss/utils/distances.cpp +164 -319
- data/vendor/faiss/faiss/utils/distances.h +28 -20
- data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
- data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
- data/vendor/faiss/faiss/utils/hamming.h +2 -7
- data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
- data/vendor/faiss/faiss/utils/partitioning.h +69 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
- data/vendor/faiss/faiss/utils/simdlib.h +31 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
- metadata +43 -141
- data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
- data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
- data/vendor/faiss/c_api/AutoTune_c.h +0 -66
- data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
- data/vendor/faiss/c_api/Clustering_c.h +0 -123
- data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
- data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
- data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
- data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
- data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
- data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
- data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
- data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
- data/vendor/faiss/c_api/IndexShards_c.h +0 -39
- data/vendor/faiss/c_api/Index_c.cpp +0 -105
- data/vendor/faiss/c_api/Index_c.h +0 -183
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
- data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
- data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
- data/vendor/faiss/c_api/clone_index_c.h +0 -32
- data/vendor/faiss/c_api/error_c.h +0 -42
- data/vendor/faiss/c_api/error_impl.cpp +0 -27
- data/vendor/faiss/c_api/error_impl.h +0 -16
- data/vendor/faiss/c_api/faiss_c.h +0 -58
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
- data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
- data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
- data/vendor/faiss/c_api/index_factory_c.h +0 -30
- data/vendor/faiss/c_api/index_io_c.cpp +0 -42
- data/vendor/faiss/c_api/index_io_c.h +0 -50
- data/vendor/faiss/c_api/macros_impl.h +0 -110
- data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
- data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
- data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
- data/vendor/faiss/misc/test_blas.cpp +0 -87
- data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
- data/vendor/faiss/tests/test_merge.cpp +0 -260
- data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
- data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
- data/vendor/faiss/tests/test_params_override.cpp +0 -236
- data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
- data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
- data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
- data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,111 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
|
9
|
+
#pragma once
|
10
|
+
|
11
|
+
#include <faiss/IndexPQ.h>
|
12
|
+
#include <faiss/impl/ProductQuantizer.h>
|
13
|
+
#include <faiss/utils/AlignedTable.h>
|
14
|
+
|
15
|
+
|
16
|
+
namespace faiss {
|
17
|
+
|
18
|
+
|
19
|
+
/** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
|
20
|
+
*
|
21
|
+
* The codes are not stored sequentially but grouped in blocks of size bbs.
|
22
|
+
* This makes it possible to compute distances quickly with SIMD instructions.
|
23
|
+
*
|
24
|
+
* Implementations:
|
25
|
+
* 12: blocked loop with internal loop on Q with qbs
|
26
|
+
* 13: same with reservoir accumulator to store results
|
27
|
+
* 14: no qbs with heap accumulator
|
28
|
+
* 15: no qbs with reservoir accumulator
|
29
|
+
*/
|
30
|
+
|
31
|
+
struct IndexPQFastScan: Index {
|
32
|
+
ProductQuantizer pq;
|
33
|
+
|
34
|
+
// implementation to select
|
35
|
+
int implem = 0;
|
36
|
+
// skip some parts of the computation (for timing)
|
37
|
+
int skip = 0;
|
38
|
+
|
39
|
+
// size of the kernel
|
40
|
+
int bbs; // set at build time
|
41
|
+
int qbs = 0; // query block size 0 = use default
|
42
|
+
|
43
|
+
// packed version of the codes
|
44
|
+
size_t ntotal2;
|
45
|
+
size_t M2;
|
46
|
+
|
47
|
+
AlignedTable<uint8_t> codes;
|
48
|
+
|
49
|
+
// this is for testing purposes only (set when initialized by IndexPQ)
|
50
|
+
const uint8_t *orig_codes = nullptr;
|
51
|
+
|
52
|
+
IndexPQFastScan(
|
53
|
+
int d, size_t M, size_t nbits,
|
54
|
+
MetricType metric = METRIC_L2,
|
55
|
+
int bbs = 32
|
56
|
+
);
|
57
|
+
|
58
|
+
IndexPQFastScan();
|
59
|
+
|
60
|
+
/// build from an existing IndexPQ
|
61
|
+
explicit IndexPQFastScan(const IndexPQ & orig, int bbs = 32);
|
62
|
+
|
63
|
+
void train (idx_t n, const float *x) override;
|
64
|
+
void add (idx_t n, const float *x) override;
|
65
|
+
void reset() override ;
|
66
|
+
void search(
|
67
|
+
idx_t n,
|
68
|
+
const float* x,
|
69
|
+
idx_t k,
|
70
|
+
float* distances,
|
71
|
+
idx_t* labels) const override;
|
72
|
+
|
73
|
+
// called by search function
|
74
|
+
void compute_quantized_LUT(
|
75
|
+
idx_t n, const float* x,
|
76
|
+
uint8_t *lut, float *normalizers) const ;
|
77
|
+
|
78
|
+
template<bool is_max>
|
79
|
+
void search_dispatch_implem(
|
80
|
+
idx_t n, const float* x, idx_t k,
|
81
|
+
float* distances, idx_t* labels) const;
|
82
|
+
|
83
|
+
template<class C>
|
84
|
+
void search_implem_2(
|
85
|
+
idx_t n, const float* x, idx_t k,
|
86
|
+
float* distances, idx_t* labels) const;
|
87
|
+
|
88
|
+
|
89
|
+
template<class C>
|
90
|
+
void search_implem_12(
|
91
|
+
idx_t n, const float* x, idx_t k,
|
92
|
+
float* distances, idx_t* labels, int impl) const;
|
93
|
+
|
94
|
+
template<class C>
|
95
|
+
void search_implem_14(
|
96
|
+
idx_t n, const float* x, idx_t k,
|
97
|
+
float* distances, idx_t* labels, int impl) const;
|
98
|
+
|
99
|
+
};
|
100
|
+
|
101
|
+
/** Four 64-bit counters, all zero after construction or reset(). */
struct FastScanStats {
    uint64_t t0, t1, t2, t3;

    FastScanStats() { reset(); }

    /// Set every counter back to zero.
    void reset() {
        // Plain assignments instead of memset: this header does not
        // include <cstring> itself, so memset is not guaranteed to be
        // declared here.
        t0 = 0;
        t1 = 0;
        t2 = 0;
        t3 = 0;
    }
};
|
108
|
+
|
109
|
+
FAISS_API extern FastScanStats FastScan_stats;
|
110
|
+
|
111
|
+
} // namespace faiss
|
@@ -15,6 +15,7 @@
|
|
15
15
|
#include <memory>
|
16
16
|
|
17
17
|
#include <faiss/impl/FaissAssert.h>
|
18
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
18
19
|
|
19
20
|
namespace faiss {
|
20
21
|
|
@@ -282,6 +283,52 @@ void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
|
|
282
283
|
}
|
283
284
|
}
|
284
285
|
|
286
|
+
namespace {
|
287
|
+
|
288
|
+
struct PreTransformDistanceComputer: DistanceComputer {
|
289
|
+
const IndexPreTransform *index;
|
290
|
+
std::unique_ptr<DistanceComputer> sub_dc;
|
291
|
+
std::unique_ptr<const float []> query;
|
292
|
+
|
293
|
+
explicit PreTransformDistanceComputer(const IndexPreTransform *index):
|
294
|
+
index(index),
|
295
|
+
sub_dc(index->index->get_distance_computer())
|
296
|
+
{}
|
297
|
+
|
298
|
+
void set_query(const float *x) override {
|
299
|
+
const float *xt = index->apply_chain (1, x);
|
300
|
+
if (xt == x) {
|
301
|
+
sub_dc->set_query (x);
|
302
|
+
} else {
|
303
|
+
query.reset(xt);
|
304
|
+
sub_dc->set_query (xt);
|
305
|
+
}
|
306
|
+
}
|
307
|
+
|
308
|
+
float symmetric_dis(idx_t i, idx_t j) override
|
309
|
+
{
|
310
|
+
return sub_dc->symmetric_dis(i, j);
|
311
|
+
}
|
312
|
+
|
313
|
+
float operator () (idx_t i) override
|
314
|
+
{
|
315
|
+
return (*sub_dc)(i);
|
316
|
+
}
|
317
|
+
|
318
|
+
};
|
319
|
+
|
320
|
+
|
321
|
+
} // anonymous namespace
|
322
|
+
|
323
|
+
|
324
|
+
/// Return a distance computer for this index: the sub-index's own one
/// when the chain is empty, otherwise a PreTransformDistanceComputer.
DistanceComputer* IndexPreTransform::get_distance_computer() const {
    if (!chain.empty()) {
        return new PreTransformDistanceComputer(this);
    }
    return index->get_distance_computer();
}
|
331
|
+
|
285
332
|
|
286
333
|
|
287
334
|
} // namespace faiss
|
@@ -77,6 +77,8 @@ struct IndexPreTransform: Index {
|
|
77
77
|
void reverse_chain (idx_t n, const float* xt, float* x) const;
|
78
78
|
|
79
79
|
|
80
|
+
DistanceComputer * get_distance_computer() const override;
|
81
|
+
|
80
82
|
/* standalone codec interface */
|
81
83
|
size_t sa_code_size () const override;
|
82
84
|
void sa_encode (idx_t n, const float *x,
|
@@ -0,0 +1,256 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
|
9
|
+
#include <faiss/IndexRefine.h>
|
10
|
+
|
11
|
+
#include <faiss/utils/distances.h>
|
12
|
+
#include <faiss/utils/utils.h>
|
13
|
+
#include <faiss/utils/Heap.h>
|
14
|
+
#include <faiss/impl/FaissAssert.h>
|
15
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
16
|
+
#include <faiss/IndexFlat.h>
|
17
|
+
|
18
|
+
namespace faiss {
|
19
|
+
|
20
|
+
|
21
|
+
|
22
|
+
/***************************************************
|
23
|
+
* IndexRefine
|
24
|
+
***************************************************/
|
25
|
+
|
26
|
+
/** Build a refining index on top of base_index.
 *
 * Dimension and metric are copied from base_index. When refine_index is
 * non-null it must agree with base_index on d, metric_type and ntotal.
 * Neither index is owned after construction.
 */
IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
        : Index(base_index->d, base_index->metric_type),
          base_index(base_index),
          refine_index(refine_index) {
    own_fields = false;
    own_refine_index = false;

    // a null refine_index is useful only to construct an IndexRefineFlat
    if (refine_index) {
        FAISS_THROW_IF_NOT (base_index->d == refine_index->d);
        FAISS_THROW_IF_NOT (base_index->metric_type == refine_index->metric_type);
        is_trained = base_index->is_trained && refine_index->is_trained;
        FAISS_THROW_IF_NOT (base_index->ntotal == refine_index->ntotal);
    }

    ntotal = base_index->ntotal;
}
|
40
|
+
|
41
|
+
/// Default constructor: no sub-indexes attached and nothing owned.
IndexRefine::IndexRefine()
        : base_index(nullptr),
          refine_index(nullptr),
          own_fields(false),
          own_refine_index(false) {}
|
46
|
+
|
47
|
+
/// Train both sub-indexes on the same n vectors, then mark this index
/// as trained.
void IndexRefine::train(idx_t n, const float* x) {
    base_index->train(n, x);
    refine_index->train(n, x);
    is_trained = true;
}
|
53
|
+
|
54
|
+
/// Add the same n vectors to both sub-indexes; ntotal is then taken
/// from the refinement index. Throws if the index is not trained.
void IndexRefine::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT (is_trained);

    base_index->add(n, x);
    refine_index->add(n, x);

    ntotal = refine_index->ntotal;
}
|
60
|
+
|
61
|
+
void IndexRefine::reset ()
|
62
|
+
{
|
63
|
+
base_index->reset ();
|
64
|
+
refine_index->reset ();
|
65
|
+
ntotal = 0;
|
66
|
+
}
|
67
|
+
|
68
|
+
namespace {
|
69
|
+
|
70
|
+
typedef faiss::Index::idx_t idx_t;
|
71
|
+
|
72
|
+
template<class C>
|
73
|
+
static void reorder_2_heaps (
|
74
|
+
idx_t n,
|
75
|
+
idx_t k, idx_t *labels, float *distances,
|
76
|
+
idx_t k_base, const idx_t *base_labels, const float *base_distances)
|
77
|
+
{
|
78
|
+
#pragma omp parallel for
|
79
|
+
for (idx_t i = 0; i < n; i++) {
|
80
|
+
idx_t *idxo = labels + i * k;
|
81
|
+
float *diso = distances + i * k;
|
82
|
+
const idx_t *idxi = base_labels + i * k_base;
|
83
|
+
const float *disi = base_distances + i * k_base;
|
84
|
+
|
85
|
+
heap_heapify<C> (k, diso, idxo, disi, idxi, k);
|
86
|
+
if (k_base != k) { // add remaining elements
|
87
|
+
heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
|
88
|
+
}
|
89
|
+
heap_reorder<C> (k, diso, idxo);
|
90
|
+
}
|
91
|
+
}
|
92
|
+
|
93
|
+
|
94
|
+
} // anonymous namespace
|
95
|
+
|
96
|
+
|
97
|
+
|
98
|
+
/** Search the base index for k * k_factor candidates per query, recompute
 *  their distances with a distance computer from refine_index, and keep
 *  the k best per query.
 *
 * @param n          number of queries
 * @param x          query vectors; query i is read at x + i * d
 * @param k          number of results per query
 * @param distances  output distances, k per query
 * @param labels     output labels, k per query
 *
 * Throws if the index is not trained, if k * k_factor is smaller than k,
 * or if the metric is neither METRIC_L2 nor METRIC_INNER_PRODUCT.
 */
void IndexRefine::search(
        idx_t n, const float* x, idx_t k,
        float* distances, idx_t* labels) const {
    FAISS_THROW_IF_NOT (is_trained);
    idx_t k_base = idx_t (k * k_factor);
    // reorder_2_heaps reads k candidates per query, so fewer than k
    // candidates would be an out-of-bounds read
    FAISS_THROW_IF_NOT_MSG (k_base >= k, "k_factor should be >= 1");

    idx_t* base_labels = labels;
    float* base_distances = distances;
    ScopeDeleter<idx_t> del1;
    ScopeDeleter<float> del2;

    // the output buffers double as candidate buffers only when the
    // sizes match; otherwise allocate temporary ones
    if (k != k_base) {
        base_labels = new idx_t [n * k_base];
        del1.set (base_labels);
        base_distances = new float [n * k_base];
        del2.set (base_distances);
    }

    base_index->search (n, x, k_base, base_distances, base_labels);

    // idx_t counter: n * k_base may not fit in an int
    for (idx_t i = 0; i < n * k_base; i++)
        assert (base_labels[i] >= -1 &&
                base_labels[i] < ntotal);

    // parallelize over queries
#pragma omp parallel if (n > 1)
    {
        std::unique_ptr<DistanceComputer> dc(
            refine_index->get_distance_computer()
        );
#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            dc->set_query(x + i * d);
            idx_t ij = i * k_base;
            for (idx_t j = 0; j < k_base; j++) {
                idx_t idx = base_labels[ij];
                // a negative label ends this query's candidate list
                if (idx < 0) break;
                base_distances[ij] = (*dc)(idx);
                ij++;
            }
        }
    }

    // sort and store result
    if (metric_type == METRIC_L2) {
        typedef CMax <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);

    } else if (metric_type == METRIC_INNER_PRODUCT) {
        typedef CMin <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);
    } else {
        FAISS_THROW_MSG("Metric type not supported");
    }
}
|
158
|
+
|
159
|
+
/// Reconstruction is routed to the refinement index.
void IndexRefine::reconstruct(idx_t key, float* recons) const {
    refine_index->reconstruct(key, recons);
}
|
162
|
+
|
163
|
+
|
164
|
+
|
165
|
+
|
166
|
+
/// Delete each sub-index only when the matching ownership flag is set.
IndexRefine::~IndexRefine() {
    if (own_fields) {
        delete base_index;
    }
    if (own_refine_index) {
        delete refine_index;
    }
}
|
171
|
+
|
172
|
+
|
173
|
+
/***************************************************
|
174
|
+
* IndexRefineFlat
|
175
|
+
***************************************************/
|
176
|
+
|
177
|
+
/** Build on top of an empty base index, with a new owned IndexFlat as
 *  the refinement index.
 *
 * The base class is constructed without a refinement index, and the
 * IndexFlat is allocated only after the emptiness check has passed.
 * This way nothing is leaked when the check throws, and the descriptive
 * error message is the one that is reported.
 */
IndexRefineFlat::IndexRefineFlat(Index* base_index)
        : IndexRefine(base_index, nullptr) {
    FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
                            "base_index should be empty in the beginning");
    is_trained = base_index->is_trained;
    refine_index = new IndexFlat(base_index->d, base_index->metric_type);
    own_refine_index = true;
}
|
185
|
+
|
186
|
+
|
187
|
+
/** Build on top of a base index and fill a new owned IndexFlat with
 *  base_index->ntotal vectors read from xb.
 */
IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
        : IndexRefine(base_index, nullptr) {
    is_trained = base_index->is_trained;

    // the refinement index is created here, so this object owns it
    refine_index = new IndexFlat(base_index->d, base_index->metric_type);
    own_refine_index = true;

    refine_index->add(base_index->ntotal, xb);
}
|
196
|
+
|
197
|
+
/// Default constructor: no sub-indexes yet, but the refinement index is
/// flagged as owned.
IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
    own_refine_index = true;
}
|
202
|
+
|
203
|
+
|
204
|
+
/** Search the base index for k * k_factor candidates per query, recompute
 *  their distances through IndexFlat::compute_distance_subset, and keep
 *  the k best per query.
 *
 * @param n          number of queries
 * @param x          query vectors
 * @param k          number of results per query
 * @param distances  output distances, k per query
 * @param labels     output labels, k per query
 *
 * Throws if the index is not trained, if k * k_factor is smaller than k,
 * if refine_index is not an IndexFlat, or if the metric is neither
 * METRIC_L2 nor METRIC_INNER_PRODUCT.
 */
void IndexRefineFlat::search(
        idx_t n, const float* x, idx_t k,
        float* distances, idx_t* labels) const {
    FAISS_THROW_IF_NOT (is_trained);
    idx_t k_base = idx_t (k * k_factor);
    // reorder_2_heaps reads k candidates per query, so fewer than k
    // candidates would be an out-of-bounds read
    FAISS_THROW_IF_NOT_MSG (k_base >= k, "k_factor should be >= 1");

    idx_t* base_labels = labels;
    float* base_distances = distances;
    ScopeDeleter<idx_t> del1;
    ScopeDeleter<float> del2;

    // the output buffers double as candidate buffers only when the
    // sizes match; otherwise allocate temporary ones
    if (k != k_base) {
        base_labels = new idx_t [n * k_base];
        del1.set (base_labels);
        base_distances = new float [n * k_base];
        del2.set (base_distances);
    }

    base_index->search (n, x, k_base, base_distances, base_labels);

    // idx_t counter: n * k_base may not fit in an int
    for (idx_t i = 0; i < n * k_base; i++)
        assert (base_labels[i] >= -1 &&
                base_labels[i] < ntotal);

    // compute refined distances
    auto rf = dynamic_cast<const IndexFlat *>(refine_index);
    FAISS_THROW_IF_NOT(rf);

    rf->compute_distance_subset (
        n, x, k_base, base_distances, base_labels);

    // sort and store result
    if (metric_type == METRIC_L2) {
        typedef CMax <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);

    } else if (metric_type == METRIC_INNER_PRODUCT) {
        typedef CMin <float, idx_t> C;
        reorder_2_heaps<C> (
            n, k, labels, distances,
            k_base, base_labels, base_distances);
    } else {
        FAISS_THROW_MSG("Metric type not supported");
    }
}
|
252
|
+
|
253
|
+
|
254
|
+
|
255
|
+
|
256
|
+
} // namespace faiss
|
@@ -0,0 +1,73 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <faiss/Index.h>
|
11
|
+
|
12
|
+
|
13
|
+
namespace faiss {
|
14
|
+
|
15
|
+
|
16
|
+
/** Index that queries in a base_index (a fast one) and refines the
|
17
|
+
* results with an exact search, hopefully improving the results.
|
18
|
+
*/
|
19
|
+
struct IndexRefine: Index {
|
20
|
+
|
21
|
+
/// faster index to pre-select the vectors that should be filtered
|
22
|
+
Index *base_index;
|
23
|
+
|
24
|
+
/// refinement index
|
25
|
+
Index *refine_index;
|
26
|
+
|
27
|
+
bool own_fields; ///< should the base index be deallocated?
|
28
|
+
bool own_refine_index; ///< same with the refinement index
|
29
|
+
|
30
|
+
/// factor between k requested in search and the k requested from
|
31
|
+
/// the base_index (should be >= 1)
|
32
|
+
float k_factor = 1;
|
33
|
+
|
34
|
+
/// intitialize from empty index
|
35
|
+
IndexRefine (Index *base_index, Index *refine_index);
|
36
|
+
|
37
|
+
IndexRefine ();
|
38
|
+
|
39
|
+
void train(idx_t n, const float* x) override;
|
40
|
+
|
41
|
+
void add(idx_t n, const float* x) override;
|
42
|
+
|
43
|
+
void reset() override;
|
44
|
+
|
45
|
+
void search(
|
46
|
+
idx_t n, const float* x, idx_t k,
|
47
|
+
float* distances, idx_t* labels) const override;
|
48
|
+
|
49
|
+
// reconstruct is routed to the refine_index
|
50
|
+
void reconstruct (idx_t key, float * recons) const override;
|
51
|
+
|
52
|
+
~IndexRefine() override;
|
53
|
+
};
|
54
|
+
|
55
|
+
|
56
|
+
/** Version where the refinement index is an IndexFlat. It has one additional
|
57
|
+
* constructor that takes a table of elements to add to the flat refinement
|
58
|
+
* index */
|
59
|
+
struct IndexRefineFlat: IndexRefine {
|
60
|
+
explicit IndexRefineFlat (Index *base_index);
|
61
|
+
IndexRefineFlat(Index *base_index, const float *xb);
|
62
|
+
|
63
|
+
IndexRefineFlat();
|
64
|
+
|
65
|
+
void search(
|
66
|
+
idx_t n, const float* x, idx_t k,
|
67
|
+
float* distances, idx_t* labels) const override;
|
68
|
+
|
69
|
+
};
|
70
|
+
|
71
|
+
|
72
|
+
|
73
|
+
} // namespace faiss
|