faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -9,8 +9,6 @@
|
|
|
9
9
|
|
|
10
10
|
#pragma once
|
|
11
11
|
|
|
12
|
-
|
|
13
|
-
|
|
14
12
|
#include <faiss/Index.h>
|
|
15
13
|
#include <faiss/VectorTransform.h>
|
|
16
14
|
|
|
@@ -18,21 +16,20 @@ namespace faiss {
|
|
|
18
16
|
|
|
19
17
|
/** Index that applies a chain of VectorTransforms to vectors before
|
|
20
18
|
* handing them over to a sub-index */
|
|
21
|
-
struct IndexPreTransform: Index {
|
|
19
|
+
struct IndexPreTransform : Index {
|
|
20
|
+
std::vector<VectorTransform*> chain; ///! chain of transforms
|
|
21
|
+
Index* index; ///! the sub-index
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
Index * index; ///! the sub-index
|
|
23
|
+
bool own_fields; ///! whether pointers are deleted in destructor
|
|
25
24
|
|
|
26
|
-
|
|
25
|
+
explicit IndexPreTransform(Index* index);
|
|
27
26
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
IndexPreTransform ();
|
|
27
|
+
IndexPreTransform();
|
|
31
28
|
|
|
32
29
|
/// ltrans is the last transform before the index
|
|
33
|
-
IndexPreTransform
|
|
30
|
+
IndexPreTransform(VectorTransform* ltrans, Index* index);
|
|
34
31
|
|
|
35
|
-
void prepend_transform
|
|
32
|
+
void prepend_transform(VectorTransform* ltrans);
|
|
36
33
|
|
|
37
34
|
void train(idx_t n, const float* x) override;
|
|
38
35
|
|
|
@@ -47,47 +44,47 @@ struct IndexPreTransform: Index {
|
|
|
47
44
|
size_t remove_ids(const IDSelector& sel) override;
|
|
48
45
|
|
|
49
46
|
void search(
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
47
|
+
idx_t n,
|
|
48
|
+
const float* x,
|
|
49
|
+
idx_t k,
|
|
50
|
+
float* distances,
|
|
51
|
+
idx_t* labels) const override;
|
|
56
52
|
|
|
57
53
|
/* range search, no attempt is made to change the radius */
|
|
58
|
-
void range_search
|
|
59
|
-
|
|
54
|
+
void range_search(
|
|
55
|
+
idx_t n,
|
|
56
|
+
const float* x,
|
|
57
|
+
float radius,
|
|
58
|
+
RangeSearchResult* result) const override;
|
|
60
59
|
|
|
60
|
+
void reconstruct(idx_t key, float* recons) const override;
|
|
61
61
|
|
|
62
|
-
void
|
|
62
|
+
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
|
|
63
63
|
|
|
64
|
-
void
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
64
|
+
void search_and_reconstruct(
|
|
65
|
+
idx_t n,
|
|
66
|
+
const float* x,
|
|
67
|
+
idx_t k,
|
|
68
|
+
float* distances,
|
|
69
|
+
idx_t* labels,
|
|
70
|
+
float* recons) const override;
|
|
70
71
|
|
|
71
72
|
/// apply the transforms in the chain. The returned float * may be
|
|
72
73
|
/// equal to x, otherwise it should be deallocated.
|
|
73
|
-
const float
|
|
74
|
+
const float* apply_chain(idx_t n, const float* x) const;
|
|
74
75
|
|
|
75
76
|
/// Reverse the transforms in the chain. May not be implemented for
|
|
76
77
|
/// all transforms in the chain or may return approximate results.
|
|
77
|
-
void reverse_chain
|
|
78
|
-
|
|
78
|
+
void reverse_chain(idx_t n, const float* xt, float* x) const;
|
|
79
79
|
|
|
80
|
-
DistanceComputer
|
|
80
|
+
DistanceComputer* get_distance_computer() const override;
|
|
81
81
|
|
|
82
82
|
/* standalone codec interface */
|
|
83
|
-
size_t sa_code_size
|
|
84
|
-
void sa_encode
|
|
85
|
-
|
|
86
|
-
void sa_decode (idx_t n, const uint8_t *bytes,
|
|
87
|
-
float *x) const override;
|
|
83
|
+
size_t sa_code_size() const override;
|
|
84
|
+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
|
85
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
88
86
|
|
|
89
87
|
~IndexPreTransform() override;
|
|
90
88
|
};
|
|
91
89
|
|
|
92
|
-
|
|
93
90
|
} // namespace faiss
|
|
@@ -5,63 +5,58 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
|
|
9
8
|
#include <faiss/IndexRefine.h>
|
|
10
9
|
|
|
10
|
+
#include <faiss/IndexFlat.h>
|
|
11
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
12
|
+
#include <faiss/impl/FaissAssert.h>
|
|
13
|
+
#include <faiss/utils/Heap.h>
|
|
11
14
|
#include <faiss/utils/distances.h>
|
|
12
15
|
#include <faiss/utils/utils.h>
|
|
13
|
-
#include <faiss/utils/Heap.h>
|
|
14
|
-
#include <faiss/impl/FaissAssert.h>
|
|
15
|
-
#include <faiss/impl/AuxIndexStructures.h>
|
|
16
|
-
#include <faiss/IndexFlat.h>
|
|
17
16
|
|
|
18
17
|
namespace faiss {
|
|
19
18
|
|
|
20
|
-
|
|
21
|
-
|
|
22
19
|
/***************************************************
|
|
23
20
|
* IndexRefine
|
|
24
21
|
***************************************************/
|
|
25
22
|
|
|
26
|
-
IndexRefine::IndexRefine
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
{
|
|
23
|
+
IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
|
|
24
|
+
: Index(base_index->d, base_index->metric_type),
|
|
25
|
+
base_index(base_index),
|
|
26
|
+
refine_index(refine_index) {
|
|
31
27
|
own_fields = own_refine_index = false;
|
|
32
28
|
if (refine_index != nullptr) {
|
|
33
|
-
FAISS_THROW_IF_NOT
|
|
34
|
-
FAISS_THROW_IF_NOT
|
|
29
|
+
FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
|
|
30
|
+
FAISS_THROW_IF_NOT(
|
|
31
|
+
base_index->metric_type == refine_index->metric_type);
|
|
35
32
|
is_trained = base_index->is_trained && refine_index->is_trained;
|
|
36
|
-
FAISS_THROW_IF_NOT
|
|
33
|
+
FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
|
|
37
34
|
} // other case is useful only to construct an IndexRefineFlat
|
|
38
35
|
ntotal = base_index->ntotal;
|
|
39
36
|
}
|
|
40
37
|
|
|
41
|
-
IndexRefine::IndexRefine
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
}
|
|
38
|
+
IndexRefine::IndexRefine()
|
|
39
|
+
: base_index(nullptr),
|
|
40
|
+
refine_index(nullptr),
|
|
41
|
+
own_fields(false),
|
|
42
|
+
own_refine_index(false) {}
|
|
46
43
|
|
|
47
|
-
void IndexRefine::train
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
refine_index->train (n, x);
|
|
44
|
+
void IndexRefine::train(idx_t n, const float* x) {
|
|
45
|
+
base_index->train(n, x);
|
|
46
|
+
refine_index->train(n, x);
|
|
51
47
|
is_trained = true;
|
|
52
48
|
}
|
|
53
49
|
|
|
54
|
-
void IndexRefine::add
|
|
55
|
-
FAISS_THROW_IF_NOT
|
|
56
|
-
base_index->add
|
|
57
|
-
refine_index->add
|
|
50
|
+
void IndexRefine::add(idx_t n, const float* x) {
|
|
51
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
52
|
+
base_index->add(n, x);
|
|
53
|
+
refine_index->add(n, x);
|
|
58
54
|
ntotal = refine_index->ntotal;
|
|
59
55
|
}
|
|
60
56
|
|
|
61
|
-
void IndexRefine::reset
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
refine_index->reset ();
|
|
57
|
+
void IndexRefine::reset() {
|
|
58
|
+
base_index->reset();
|
|
59
|
+
refine_index->reset();
|
|
65
60
|
ntotal = 0;
|
|
66
61
|
}
|
|
67
62
|
|
|
@@ -69,69 +64,72 @@ namespace {
|
|
|
69
64
|
|
|
70
65
|
typedef faiss::Index::idx_t idx_t;
|
|
71
66
|
|
|
72
|
-
template<class C>
|
|
73
|
-
static void reorder_2_heaps
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
67
|
+
template <class C>
|
|
68
|
+
static void reorder_2_heaps(
|
|
69
|
+
idx_t n,
|
|
70
|
+
idx_t k,
|
|
71
|
+
idx_t* labels,
|
|
72
|
+
float* distances,
|
|
73
|
+
idx_t k_base,
|
|
74
|
+
const idx_t* base_labels,
|
|
75
|
+
const float* base_distances) {
|
|
78
76
|
#pragma omp parallel for
|
|
79
77
|
for (idx_t i = 0; i < n; i++) {
|
|
80
|
-
idx_t
|
|
81
|
-
float
|
|
82
|
-
const idx_t
|
|
83
|
-
const float
|
|
78
|
+
idx_t* idxo = labels + i * k;
|
|
79
|
+
float* diso = distances + i * k;
|
|
80
|
+
const idx_t* idxi = base_labels + i * k_base;
|
|
81
|
+
const float* disi = base_distances + i * k_base;
|
|
84
82
|
|
|
85
|
-
heap_heapify<C>
|
|
83
|
+
heap_heapify<C>(k, diso, idxo, disi, idxi, k);
|
|
86
84
|
if (k_base != k) { // add remaining elements
|
|
87
|
-
heap_addn<C>
|
|
85
|
+
heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
|
|
88
86
|
}
|
|
89
|
-
heap_reorder<C>
|
|
87
|
+
heap_reorder<C>(k, diso, idxo);
|
|
90
88
|
}
|
|
91
89
|
}
|
|
92
90
|
|
|
93
|
-
|
|
94
91
|
} // anonymous namespace
|
|
95
92
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
{
|
|
102
|
-
FAISS_THROW_IF_NOT
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
93
|
+
void IndexRefine::search(
|
|
94
|
+
idx_t n,
|
|
95
|
+
const float* x,
|
|
96
|
+
idx_t k,
|
|
97
|
+
float* distances,
|
|
98
|
+
idx_t* labels) const {
|
|
99
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
100
|
+
|
|
101
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
102
|
+
idx_t k_base = idx_t(k * k_factor);
|
|
103
|
+
idx_t* base_labels = labels;
|
|
104
|
+
float* base_distances = distances;
|
|
106
105
|
ScopeDeleter<idx_t> del1;
|
|
107
106
|
ScopeDeleter<float> del2;
|
|
108
107
|
|
|
109
108
|
if (k != k_base) {
|
|
110
|
-
base_labels = new idx_t
|
|
111
|
-
del1.set
|
|
112
|
-
base_distances = new float
|
|
113
|
-
del2.set
|
|
109
|
+
base_labels = new idx_t[n * k_base];
|
|
110
|
+
del1.set(base_labels);
|
|
111
|
+
base_distances = new float[n * k_base];
|
|
112
|
+
del2.set(base_distances);
|
|
114
113
|
}
|
|
115
114
|
|
|
116
|
-
base_index->search
|
|
115
|
+
base_index->search(n, x, k_base, base_distances, base_labels);
|
|
117
116
|
|
|
118
117
|
for (int i = 0; i < n * k_base; i++)
|
|
119
|
-
assert
|
|
120
|
-
base_labels[i] < ntotal);
|
|
118
|
+
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
121
119
|
|
|
122
|
-
|
|
120
|
+
// parallelize over queries
|
|
123
121
|
#pragma omp parallel if (n > 1)
|
|
124
122
|
{
|
|
125
123
|
std::unique_ptr<DistanceComputer> dc(
|
|
126
|
-
|
|
127
|
-
);
|
|
124
|
+
refine_index->get_distance_computer());
|
|
128
125
|
#pragma omp for
|
|
129
126
|
for (idx_t i = 0; i < n; i++) {
|
|
130
127
|
dc->set_query(x + i * d);
|
|
131
128
|
idx_t ij = i * k_base;
|
|
132
129
|
for (idx_t j = 0; j < k_base; j++) {
|
|
133
130
|
idx_t idx = base_labels[ij];
|
|
134
|
-
if (idx < 0)
|
|
131
|
+
if (idx < 0)
|
|
132
|
+
break;
|
|
135
133
|
base_distances[ij] = (*dc)(idx);
|
|
136
134
|
ij++;
|
|
137
135
|
}
|
|
@@ -140,117 +138,131 @@ void IndexRefine::search (
|
|
|
140
138
|
|
|
141
139
|
// sort and store result
|
|
142
140
|
if (metric_type == METRIC_L2) {
|
|
143
|
-
typedef CMax
|
|
144
|
-
reorder_2_heaps<C>
|
|
145
|
-
|
|
146
|
-
k_base, base_labels, base_distances);
|
|
141
|
+
typedef CMax<float, idx_t> C;
|
|
142
|
+
reorder_2_heaps<C>(
|
|
143
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
147
144
|
|
|
148
145
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
149
|
-
typedef CMin
|
|
150
|
-
reorder_2_heaps<C>
|
|
151
|
-
|
|
152
|
-
k_base, base_labels, base_distances);
|
|
146
|
+
typedef CMin<float, idx_t> C;
|
|
147
|
+
reorder_2_heaps<C>(
|
|
148
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
153
149
|
} else {
|
|
154
150
|
FAISS_THROW_MSG("Metric type not supported");
|
|
155
151
|
}
|
|
156
|
-
|
|
157
152
|
}
|
|
158
153
|
|
|
159
|
-
void IndexRefine::reconstruct
|
|
160
|
-
refine_index->reconstruct
|
|
154
|
+
void IndexRefine::reconstruct(idx_t key, float* recons) const {
|
|
155
|
+
refine_index->reconstruct(key, recons);
|
|
161
156
|
}
|
|
162
157
|
|
|
158
|
+
size_t IndexRefine::sa_code_size() const {
|
|
159
|
+
return base_index->sa_code_size() + refine_index->sa_code_size();
|
|
160
|
+
}
|
|
163
161
|
|
|
162
|
+
void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
|
163
|
+
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
|
|
164
|
+
std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
|
|
165
|
+
base_index->sa_encode(n, x, tmp1.get());
|
|
166
|
+
std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
|
|
167
|
+
refine_index->sa_encode(n, x, tmp2.get());
|
|
168
|
+
for (size_t i = 0; i < n; i++) {
|
|
169
|
+
uint8_t* b = bytes + i * (cs1 + cs2);
|
|
170
|
+
memcpy(b, tmp1.get() + cs1 * i, cs1);
|
|
171
|
+
memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
|
|
172
|
+
}
|
|
173
|
+
}
|
|
164
174
|
|
|
175
|
+
void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
176
|
+
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
|
|
177
|
+
std::unique_ptr<uint8_t[]> tmp2(
|
|
178
|
+
new uint8_t[n * refine_index->sa_code_size()]);
|
|
179
|
+
for (size_t i = 0; i < n; i++) {
|
|
180
|
+
memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
|
|
181
|
+
}
|
|
165
182
|
|
|
166
|
-
|
|
167
|
-
{
|
|
168
|
-
if (own_fields) delete base_index;
|
|
169
|
-
if (own_refine_index) delete refine_index;
|
|
183
|
+
refine_index->sa_decode(n, tmp2.get(), x);
|
|
170
184
|
}
|
|
171
185
|
|
|
186
|
+
IndexRefine::~IndexRefine() {
|
|
187
|
+
if (own_fields)
|
|
188
|
+
delete base_index;
|
|
189
|
+
if (own_refine_index)
|
|
190
|
+
delete refine_index;
|
|
191
|
+
}
|
|
172
192
|
|
|
173
193
|
/***************************************************
|
|
174
194
|
* IndexRefineFlat
|
|
175
195
|
***************************************************/
|
|
176
196
|
|
|
177
|
-
IndexRefineFlat::IndexRefineFlat
|
|
178
|
-
|
|
179
|
-
|
|
197
|
+
IndexRefineFlat::IndexRefineFlat(Index* base_index)
|
|
198
|
+
: IndexRefine(
|
|
199
|
+
base_index,
|
|
200
|
+
new IndexFlat(base_index->d, base_index->metric_type)) {
|
|
180
201
|
is_trained = base_index->is_trained;
|
|
181
202
|
own_refine_index = true;
|
|
182
|
-
FAISS_THROW_IF_NOT_MSG
|
|
183
|
-
|
|
203
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
204
|
+
base_index->ntotal == 0,
|
|
205
|
+
"base_index should be empty in the beginning");
|
|
184
206
|
}
|
|
185
207
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
IndexRefine (base_index, nullptr)
|
|
189
|
-
{
|
|
208
|
+
IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
|
|
209
|
+
: IndexRefine(base_index, nullptr) {
|
|
190
210
|
is_trained = base_index->is_trained;
|
|
191
211
|
refine_index = new IndexFlat(base_index->d, base_index->metric_type);
|
|
192
212
|
own_refine_index = true;
|
|
193
|
-
refine_index->add
|
|
194
|
-
|
|
213
|
+
refine_index->add(base_index->ntotal, xb);
|
|
195
214
|
}
|
|
196
215
|
|
|
197
|
-
IndexRefineFlat::IndexRefineFlat():
|
|
198
|
-
IndexRefine()
|
|
199
|
-
{
|
|
216
|
+
IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
|
|
200
217
|
own_refine_index = true;
|
|
201
218
|
}
|
|
202
219
|
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
220
|
+
void IndexRefineFlat::search(
|
|
221
|
+
idx_t n,
|
|
222
|
+
const float* x,
|
|
223
|
+
idx_t k,
|
|
224
|
+
float* distances,
|
|
225
|
+
idx_t* labels) const {
|
|
226
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
227
|
+
|
|
228
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
229
|
+
idx_t k_base = idx_t(k * k_factor);
|
|
230
|
+
idx_t* base_labels = labels;
|
|
231
|
+
float* base_distances = distances;
|
|
212
232
|
ScopeDeleter<idx_t> del1;
|
|
213
233
|
ScopeDeleter<float> del2;
|
|
214
234
|
|
|
215
235
|
if (k != k_base) {
|
|
216
|
-
base_labels = new idx_t
|
|
217
|
-
del1.set
|
|
218
|
-
base_distances = new float
|
|
219
|
-
del2.set
|
|
236
|
+
base_labels = new idx_t[n * k_base];
|
|
237
|
+
del1.set(base_labels);
|
|
238
|
+
base_distances = new float[n * k_base];
|
|
239
|
+
del2.set(base_distances);
|
|
220
240
|
}
|
|
221
241
|
|
|
222
|
-
base_index->search
|
|
242
|
+
base_index->search(n, x, k_base, base_distances, base_labels);
|
|
223
243
|
|
|
224
244
|
for (int i = 0; i < n * k_base; i++)
|
|
225
|
-
assert
|
|
226
|
-
base_labels[i] < ntotal);
|
|
245
|
+
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
227
246
|
|
|
228
247
|
// compute refined distances
|
|
229
|
-
auto rf = dynamic_cast<const IndexFlat
|
|
248
|
+
auto rf = dynamic_cast<const IndexFlat*>(refine_index);
|
|
230
249
|
FAISS_THROW_IF_NOT(rf);
|
|
231
250
|
|
|
232
|
-
rf->compute_distance_subset
|
|
233
|
-
n, x, k_base, base_distances, base_labels);
|
|
251
|
+
rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
|
|
234
252
|
|
|
235
253
|
// sort and store result
|
|
236
254
|
if (metric_type == METRIC_L2) {
|
|
237
|
-
typedef CMax
|
|
238
|
-
reorder_2_heaps<C>
|
|
239
|
-
|
|
240
|
-
k_base, base_labels, base_distances);
|
|
255
|
+
typedef CMax<float, idx_t> C;
|
|
256
|
+
reorder_2_heaps<C>(
|
|
257
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
241
258
|
|
|
242
259
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
243
|
-
typedef CMin
|
|
244
|
-
reorder_2_heaps<C>
|
|
245
|
-
|
|
246
|
-
k_base, base_labels, base_distances);
|
|
260
|
+
typedef CMin<float, idx_t> C;
|
|
261
|
+
reorder_2_heaps<C>(
|
|
262
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
247
263
|
} else {
|
|
248
264
|
FAISS_THROW_MSG("Metric type not supported");
|
|
249
265
|
}
|
|
250
|
-
|
|
251
266
|
}
|
|
252
267
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
268
|
} // namespace faiss
|
|
@@ -9,32 +9,29 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/Index.h>
|
|
11
11
|
|
|
12
|
-
|
|
13
12
|
namespace faiss {
|
|
14
13
|
|
|
15
|
-
|
|
16
14
|
/** Index that queries in a base_index (a fast one) and refines the
|
|
17
15
|
* results with an exact search, hopefully improving the results.
|
|
18
16
|
*/
|
|
19
|
-
struct IndexRefine: Index {
|
|
20
|
-
|
|
17
|
+
struct IndexRefine : Index {
|
|
21
18
|
/// faster index to pre-select the vectors that should be filtered
|
|
22
|
-
Index
|
|
19
|
+
Index* base_index;
|
|
23
20
|
|
|
24
21
|
/// refinement index
|
|
25
|
-
Index
|
|
22
|
+
Index* refine_index;
|
|
26
23
|
|
|
27
|
-
bool own_fields;
|
|
28
|
-
bool own_refine_index;
|
|
24
|
+
bool own_fields; ///< should the base index be deallocated?
|
|
25
|
+
bool own_refine_index; ///< same with the refinement index
|
|
29
26
|
|
|
30
27
|
/// factor between k requested in search and the k requested from
|
|
31
28
|
/// the base_index (should be >= 1)
|
|
32
29
|
float k_factor = 1;
|
|
33
30
|
|
|
34
|
-
///
|
|
35
|
-
IndexRefine
|
|
31
|
+
/// initialize from empty index
|
|
32
|
+
IndexRefine(Index* base_index, Index* refine_index);
|
|
36
33
|
|
|
37
|
-
IndexRefine
|
|
34
|
+
IndexRefine();
|
|
38
35
|
|
|
39
36
|
void train(idx_t n, const float* x) override;
|
|
40
37
|
|
|
@@ -43,31 +40,43 @@ struct IndexRefine: Index {
|
|
|
43
40
|
void reset() override;
|
|
44
41
|
|
|
45
42
|
void search(
|
|
46
|
-
|
|
47
|
-
|
|
43
|
+
idx_t n,
|
|
44
|
+
const float* x,
|
|
45
|
+
idx_t k,
|
|
46
|
+
float* distances,
|
|
47
|
+
idx_t* labels) const override;
|
|
48
48
|
|
|
49
49
|
// reconstruct is routed to the refine_index
|
|
50
|
-
void reconstruct
|
|
50
|
+
void reconstruct(idx_t key, float* recons) const override;
|
|
51
|
+
|
|
52
|
+
/* standalone codec interface: the base_index codes are interleaved with the
|
|
53
|
+
* refine_index ones */
|
|
54
|
+
size_t sa_code_size() const override;
|
|
55
|
+
|
|
56
|
+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
|
57
|
+
|
|
58
|
+
/// sa_decode decodes from the refine_index, which is assumed to be more
|
|
59
|
+
/// accurate
|
|
60
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
51
61
|
|
|
52
62
|
~IndexRefine() override;
|
|
53
63
|
};
|
|
54
64
|
|
|
55
|
-
|
|
56
65
|
/** Version where the refinement index is an IndexFlat. It has one additional
|
|
57
66
|
* constructor that takes a table of elements to add to the flat refinement
|
|
58
67
|
* index */
|
|
59
|
-
struct IndexRefineFlat: IndexRefine {
|
|
60
|
-
explicit IndexRefineFlat
|
|
61
|
-
IndexRefineFlat(Index
|
|
68
|
+
struct IndexRefineFlat : IndexRefine {
|
|
69
|
+
explicit IndexRefineFlat(Index* base_index);
|
|
70
|
+
IndexRefineFlat(Index* base_index, const float* xb);
|
|
62
71
|
|
|
63
72
|
IndexRefineFlat();
|
|
64
73
|
|
|
65
74
|
void search(
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
75
|
+
idx_t n,
|
|
76
|
+
const float* x,
|
|
77
|
+
idx_t k,
|
|
78
|
+
float* distances,
|
|
79
|
+
idx_t* labels) const override;
|
|
69
80
|
};
|
|
70
81
|
|
|
71
|
-
|
|
72
|
-
|
|
73
82
|
} // namespace faiss
|