faiss 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +20 -2
|
@@ -9,8 +9,6 @@
|
|
|
9
9
|
|
|
10
10
|
#pragma once
|
|
11
11
|
|
|
12
|
-
|
|
13
|
-
|
|
14
12
|
#include <faiss/Index.h>
|
|
15
13
|
#include <faiss/VectorTransform.h>
|
|
16
14
|
|
|
@@ -18,21 +16,20 @@ namespace faiss {
|
|
|
18
16
|
|
|
19
17
|
/** Index that applies a LinearTransform transform on vectors before
|
|
20
18
|
* handing them over to a sub-index */
|
|
21
|
-
struct IndexPreTransform: Index {
|
|
19
|
+
struct IndexPreTransform : Index {
|
|
20
|
+
std::vector<VectorTransform*> chain; ///! chain of tranforms
|
|
21
|
+
Index* index; ///! the sub-index
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
Index * index; ///! the sub-index
|
|
23
|
+
bool own_fields; ///! whether pointers are deleted in destructor
|
|
25
24
|
|
|
26
|
-
|
|
25
|
+
explicit IndexPreTransform(Index* index);
|
|
27
26
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
IndexPreTransform ();
|
|
27
|
+
IndexPreTransform();
|
|
31
28
|
|
|
32
29
|
/// ltrans is the last transform before the index
|
|
33
|
-
IndexPreTransform
|
|
30
|
+
IndexPreTransform(VectorTransform* ltrans, Index* index);
|
|
34
31
|
|
|
35
|
-
void prepend_transform
|
|
32
|
+
void prepend_transform(VectorTransform* ltrans);
|
|
36
33
|
|
|
37
34
|
void train(idx_t n, const float* x) override;
|
|
38
35
|
|
|
@@ -47,47 +44,47 @@ struct IndexPreTransform: Index {
|
|
|
47
44
|
size_t remove_ids(const IDSelector& sel) override;
|
|
48
45
|
|
|
49
46
|
void search(
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
47
|
+
idx_t n,
|
|
48
|
+
const float* x,
|
|
49
|
+
idx_t k,
|
|
50
|
+
float* distances,
|
|
51
|
+
idx_t* labels) const override;
|
|
56
52
|
|
|
57
53
|
/* range search, no attempt is done to change the radius */
|
|
58
|
-
void range_search
|
|
59
|
-
|
|
54
|
+
void range_search(
|
|
55
|
+
idx_t n,
|
|
56
|
+
const float* x,
|
|
57
|
+
float radius,
|
|
58
|
+
RangeSearchResult* result) const override;
|
|
60
59
|
|
|
60
|
+
void reconstruct(idx_t key, float* recons) const override;
|
|
61
61
|
|
|
62
|
-
void
|
|
62
|
+
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
|
|
63
63
|
|
|
64
|
-
void
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
64
|
+
void search_and_reconstruct(
|
|
65
|
+
idx_t n,
|
|
66
|
+
const float* x,
|
|
67
|
+
idx_t k,
|
|
68
|
+
float* distances,
|
|
69
|
+
idx_t* labels,
|
|
70
|
+
float* recons) const override;
|
|
70
71
|
|
|
71
72
|
/// apply the transforms in the chain. The returned float * may be
|
|
72
73
|
/// equal to x, otherwise it should be deallocated.
|
|
73
|
-
const float
|
|
74
|
+
const float* apply_chain(idx_t n, const float* x) const;
|
|
74
75
|
|
|
75
76
|
/// Reverse the transforms in the chain. May not be implemented for
|
|
76
77
|
/// all transforms in the chain or may return approximate results.
|
|
77
|
-
void reverse_chain
|
|
78
|
-
|
|
78
|
+
void reverse_chain(idx_t n, const float* xt, float* x) const;
|
|
79
79
|
|
|
80
|
-
DistanceComputer
|
|
80
|
+
DistanceComputer* get_distance_computer() const override;
|
|
81
81
|
|
|
82
82
|
/* standalone codec interface */
|
|
83
|
-
size_t sa_code_size
|
|
84
|
-
void sa_encode
|
|
85
|
-
|
|
86
|
-
void sa_decode (idx_t n, const uint8_t *bytes,
|
|
87
|
-
float *x) const override;
|
|
83
|
+
size_t sa_code_size() const override;
|
|
84
|
+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
|
85
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
88
86
|
|
|
89
87
|
~IndexPreTransform() override;
|
|
90
88
|
};
|
|
91
89
|
|
|
92
|
-
|
|
93
90
|
} // namespace faiss
|
|
@@ -5,63 +5,58 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
|
|
9
8
|
#include <faiss/IndexRefine.h>
|
|
10
9
|
|
|
10
|
+
#include <faiss/IndexFlat.h>
|
|
11
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
12
|
+
#include <faiss/impl/FaissAssert.h>
|
|
13
|
+
#include <faiss/utils/Heap.h>
|
|
11
14
|
#include <faiss/utils/distances.h>
|
|
12
15
|
#include <faiss/utils/utils.h>
|
|
13
|
-
#include <faiss/utils/Heap.h>
|
|
14
|
-
#include <faiss/impl/FaissAssert.h>
|
|
15
|
-
#include <faiss/impl/AuxIndexStructures.h>
|
|
16
|
-
#include <faiss/IndexFlat.h>
|
|
17
16
|
|
|
18
17
|
namespace faiss {
|
|
19
18
|
|
|
20
|
-
|
|
21
|
-
|
|
22
19
|
/***************************************************
|
|
23
20
|
* IndexRefine
|
|
24
21
|
***************************************************/
|
|
25
22
|
|
|
26
|
-
IndexRefine::IndexRefine
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
{
|
|
23
|
+
IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
|
|
24
|
+
: Index(base_index->d, base_index->metric_type),
|
|
25
|
+
base_index(base_index),
|
|
26
|
+
refine_index(refine_index) {
|
|
31
27
|
own_fields = own_refine_index = false;
|
|
32
28
|
if (refine_index != nullptr) {
|
|
33
|
-
FAISS_THROW_IF_NOT
|
|
34
|
-
FAISS_THROW_IF_NOT
|
|
29
|
+
FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
|
|
30
|
+
FAISS_THROW_IF_NOT(
|
|
31
|
+
base_index->metric_type == refine_index->metric_type);
|
|
35
32
|
is_trained = base_index->is_trained && refine_index->is_trained;
|
|
36
|
-
FAISS_THROW_IF_NOT
|
|
33
|
+
FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
|
|
37
34
|
} // other case is useful only to construct an IndexRefineFlat
|
|
38
35
|
ntotal = base_index->ntotal;
|
|
39
36
|
}
|
|
40
37
|
|
|
41
|
-
IndexRefine::IndexRefine
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
}
|
|
38
|
+
IndexRefine::IndexRefine()
|
|
39
|
+
: base_index(nullptr),
|
|
40
|
+
refine_index(nullptr),
|
|
41
|
+
own_fields(false),
|
|
42
|
+
own_refine_index(false) {}
|
|
46
43
|
|
|
47
|
-
void IndexRefine::train
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
refine_index->train (n, x);
|
|
44
|
+
void IndexRefine::train(idx_t n, const float* x) {
|
|
45
|
+
base_index->train(n, x);
|
|
46
|
+
refine_index->train(n, x);
|
|
51
47
|
is_trained = true;
|
|
52
48
|
}
|
|
53
49
|
|
|
54
|
-
void IndexRefine::add
|
|
55
|
-
FAISS_THROW_IF_NOT
|
|
56
|
-
base_index->add
|
|
57
|
-
refine_index->add
|
|
50
|
+
void IndexRefine::add(idx_t n, const float* x) {
|
|
51
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
52
|
+
base_index->add(n, x);
|
|
53
|
+
refine_index->add(n, x);
|
|
58
54
|
ntotal = refine_index->ntotal;
|
|
59
55
|
}
|
|
60
56
|
|
|
61
|
-
void IndexRefine::reset
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
refine_index->reset ();
|
|
57
|
+
void IndexRefine::reset() {
|
|
58
|
+
base_index->reset();
|
|
59
|
+
refine_index->reset();
|
|
65
60
|
ntotal = 0;
|
|
66
61
|
}
|
|
67
62
|
|
|
@@ -69,69 +64,72 @@ namespace {
|
|
|
69
64
|
|
|
70
65
|
typedef faiss::Index::idx_t idx_t;
|
|
71
66
|
|
|
72
|
-
template<class C>
|
|
73
|
-
static void reorder_2_heaps
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
67
|
+
template <class C>
|
|
68
|
+
static void reorder_2_heaps(
|
|
69
|
+
idx_t n,
|
|
70
|
+
idx_t k,
|
|
71
|
+
idx_t* labels,
|
|
72
|
+
float* distances,
|
|
73
|
+
idx_t k_base,
|
|
74
|
+
const idx_t* base_labels,
|
|
75
|
+
const float* base_distances) {
|
|
78
76
|
#pragma omp parallel for
|
|
79
77
|
for (idx_t i = 0; i < n; i++) {
|
|
80
|
-
idx_t
|
|
81
|
-
float
|
|
82
|
-
const idx_t
|
|
83
|
-
const float
|
|
78
|
+
idx_t* idxo = labels + i * k;
|
|
79
|
+
float* diso = distances + i * k;
|
|
80
|
+
const idx_t* idxi = base_labels + i * k_base;
|
|
81
|
+
const float* disi = base_distances + i * k_base;
|
|
84
82
|
|
|
85
|
-
heap_heapify<C>
|
|
83
|
+
heap_heapify<C>(k, diso, idxo, disi, idxi, k);
|
|
86
84
|
if (k_base != k) { // add remaining elements
|
|
87
|
-
heap_addn<C>
|
|
85
|
+
heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
|
|
88
86
|
}
|
|
89
|
-
heap_reorder<C>
|
|
87
|
+
heap_reorder<C>(k, diso, idxo);
|
|
90
88
|
}
|
|
91
89
|
}
|
|
92
90
|
|
|
93
|
-
|
|
94
91
|
} // anonymous namespace
|
|
95
92
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
{
|
|
102
|
-
FAISS_THROW_IF_NOT
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
93
|
+
void IndexRefine::search(
|
|
94
|
+
idx_t n,
|
|
95
|
+
const float* x,
|
|
96
|
+
idx_t k,
|
|
97
|
+
float* distances,
|
|
98
|
+
idx_t* labels) const {
|
|
99
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
100
|
+
|
|
101
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
102
|
+
idx_t k_base = idx_t(k * k_factor);
|
|
103
|
+
idx_t* base_labels = labels;
|
|
104
|
+
float* base_distances = distances;
|
|
106
105
|
ScopeDeleter<idx_t> del1;
|
|
107
106
|
ScopeDeleter<float> del2;
|
|
108
107
|
|
|
109
108
|
if (k != k_base) {
|
|
110
|
-
base_labels = new idx_t
|
|
111
|
-
del1.set
|
|
112
|
-
base_distances = new float
|
|
113
|
-
del2.set
|
|
109
|
+
base_labels = new idx_t[n * k_base];
|
|
110
|
+
del1.set(base_labels);
|
|
111
|
+
base_distances = new float[n * k_base];
|
|
112
|
+
del2.set(base_distances);
|
|
114
113
|
}
|
|
115
114
|
|
|
116
|
-
base_index->search
|
|
115
|
+
base_index->search(n, x, k_base, base_distances, base_labels);
|
|
117
116
|
|
|
118
117
|
for (int i = 0; i < n * k_base; i++)
|
|
119
|
-
assert
|
|
120
|
-
base_labels[i] < ntotal);
|
|
118
|
+
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
121
119
|
|
|
122
|
-
|
|
120
|
+
// parallelize over queries
|
|
123
121
|
#pragma omp parallel if (n > 1)
|
|
124
122
|
{
|
|
125
123
|
std::unique_ptr<DistanceComputer> dc(
|
|
126
|
-
|
|
127
|
-
);
|
|
124
|
+
refine_index->get_distance_computer());
|
|
128
125
|
#pragma omp for
|
|
129
126
|
for (idx_t i = 0; i < n; i++) {
|
|
130
127
|
dc->set_query(x + i * d);
|
|
131
128
|
idx_t ij = i * k_base;
|
|
132
129
|
for (idx_t j = 0; j < k_base; j++) {
|
|
133
130
|
idx_t idx = base_labels[ij];
|
|
134
|
-
if (idx < 0)
|
|
131
|
+
if (idx < 0)
|
|
132
|
+
break;
|
|
135
133
|
base_distances[ij] = (*dc)(idx);
|
|
136
134
|
ij++;
|
|
137
135
|
}
|
|
@@ -140,117 +138,103 @@ void IndexRefine::search (
|
|
|
140
138
|
|
|
141
139
|
// sort and store result
|
|
142
140
|
if (metric_type == METRIC_L2) {
|
|
143
|
-
typedef CMax
|
|
144
|
-
reorder_2_heaps<C>
|
|
145
|
-
|
|
146
|
-
k_base, base_labels, base_distances);
|
|
141
|
+
typedef CMax<float, idx_t> C;
|
|
142
|
+
reorder_2_heaps<C>(
|
|
143
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
147
144
|
|
|
148
145
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
149
|
-
typedef CMin
|
|
150
|
-
reorder_2_heaps<C>
|
|
151
|
-
|
|
152
|
-
k_base, base_labels, base_distances);
|
|
146
|
+
typedef CMin<float, idx_t> C;
|
|
147
|
+
reorder_2_heaps<C>(
|
|
148
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
153
149
|
} else {
|
|
154
150
|
FAISS_THROW_MSG("Metric type not supported");
|
|
155
151
|
}
|
|
156
|
-
|
|
157
152
|
}
|
|
158
153
|
|
|
159
|
-
void IndexRefine::reconstruct
|
|
160
|
-
refine_index->reconstruct
|
|
154
|
+
void IndexRefine::reconstruct(idx_t key, float* recons) const {
|
|
155
|
+
refine_index->reconstruct(key, recons);
|
|
161
156
|
}
|
|
162
157
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
if (own_fields) delete base_index;
|
|
169
|
-
if (own_refine_index) delete refine_index;
|
|
158
|
+
IndexRefine::~IndexRefine() {
|
|
159
|
+
if (own_fields)
|
|
160
|
+
delete base_index;
|
|
161
|
+
if (own_refine_index)
|
|
162
|
+
delete refine_index;
|
|
170
163
|
}
|
|
171
164
|
|
|
172
|
-
|
|
173
165
|
/***************************************************
|
|
174
166
|
* IndexRefineFlat
|
|
175
167
|
***************************************************/
|
|
176
168
|
|
|
177
|
-
IndexRefineFlat::IndexRefineFlat
|
|
178
|
-
|
|
179
|
-
|
|
169
|
+
IndexRefineFlat::IndexRefineFlat(Index* base_index)
|
|
170
|
+
: IndexRefine(
|
|
171
|
+
base_index,
|
|
172
|
+
new IndexFlat(base_index->d, base_index->metric_type)) {
|
|
180
173
|
is_trained = base_index->is_trained;
|
|
181
174
|
own_refine_index = true;
|
|
182
|
-
FAISS_THROW_IF_NOT_MSG
|
|
183
|
-
|
|
175
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
176
|
+
base_index->ntotal == 0,
|
|
177
|
+
"base_index should be empty in the beginning");
|
|
184
178
|
}
|
|
185
179
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
IndexRefine (base_index, nullptr)
|
|
189
|
-
{
|
|
180
|
+
IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
|
|
181
|
+
: IndexRefine(base_index, nullptr) {
|
|
190
182
|
is_trained = base_index->is_trained;
|
|
191
183
|
refine_index = new IndexFlat(base_index->d, base_index->metric_type);
|
|
192
184
|
own_refine_index = true;
|
|
193
|
-
refine_index->add
|
|
194
|
-
|
|
185
|
+
refine_index->add(base_index->ntotal, xb);
|
|
195
186
|
}
|
|
196
187
|
|
|
197
|
-
IndexRefineFlat::IndexRefineFlat():
|
|
198
|
-
IndexRefine()
|
|
199
|
-
{
|
|
188
|
+
IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
|
|
200
189
|
own_refine_index = true;
|
|
201
190
|
}
|
|
202
191
|
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
192
|
+
void IndexRefineFlat::search(
|
|
193
|
+
idx_t n,
|
|
194
|
+
const float* x,
|
|
195
|
+
idx_t k,
|
|
196
|
+
float* distances,
|
|
197
|
+
idx_t* labels) const {
|
|
198
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
199
|
+
|
|
200
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
201
|
+
idx_t k_base = idx_t(k * k_factor);
|
|
202
|
+
idx_t* base_labels = labels;
|
|
203
|
+
float* base_distances = distances;
|
|
212
204
|
ScopeDeleter<idx_t> del1;
|
|
213
205
|
ScopeDeleter<float> del2;
|
|
214
206
|
|
|
215
207
|
if (k != k_base) {
|
|
216
|
-
base_labels = new idx_t
|
|
217
|
-
del1.set
|
|
218
|
-
base_distances = new float
|
|
219
|
-
del2.set
|
|
208
|
+
base_labels = new idx_t[n * k_base];
|
|
209
|
+
del1.set(base_labels);
|
|
210
|
+
base_distances = new float[n * k_base];
|
|
211
|
+
del2.set(base_distances);
|
|
220
212
|
}
|
|
221
213
|
|
|
222
|
-
base_index->search
|
|
214
|
+
base_index->search(n, x, k_base, base_distances, base_labels);
|
|
223
215
|
|
|
224
216
|
for (int i = 0; i < n * k_base; i++)
|
|
225
|
-
assert
|
|
226
|
-
base_labels[i] < ntotal);
|
|
217
|
+
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
227
218
|
|
|
228
219
|
// compute refined distances
|
|
229
|
-
auto rf = dynamic_cast<const IndexFlat
|
|
220
|
+
auto rf = dynamic_cast<const IndexFlat*>(refine_index);
|
|
230
221
|
FAISS_THROW_IF_NOT(rf);
|
|
231
222
|
|
|
232
|
-
rf->compute_distance_subset
|
|
233
|
-
n, x, k_base, base_distances, base_labels);
|
|
223
|
+
rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
|
|
234
224
|
|
|
235
225
|
// sort and store result
|
|
236
226
|
if (metric_type == METRIC_L2) {
|
|
237
|
-
typedef CMax
|
|
238
|
-
reorder_2_heaps<C>
|
|
239
|
-
|
|
240
|
-
k_base, base_labels, base_distances);
|
|
227
|
+
typedef CMax<float, idx_t> C;
|
|
228
|
+
reorder_2_heaps<C>(
|
|
229
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
241
230
|
|
|
242
231
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
243
|
-
typedef CMin
|
|
244
|
-
reorder_2_heaps<C>
|
|
245
|
-
|
|
246
|
-
k_base, base_labels, base_distances);
|
|
232
|
+
typedef CMin<float, idx_t> C;
|
|
233
|
+
reorder_2_heaps<C>(
|
|
234
|
+
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
247
235
|
} else {
|
|
248
236
|
FAISS_THROW_MSG("Metric type not supported");
|
|
249
237
|
}
|
|
250
|
-
|
|
251
238
|
}
|
|
252
239
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
240
|
} // namespace faiss
|