faiss 0.1.7 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/README.md +7 -7
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +8 -2
- data/ext/faiss/index.cpp +102 -69
- data/ext/faiss/index_binary.cpp +24 -30
- data/ext/faiss/kmeans.cpp +20 -16
- data/ext/faiss/numo.hpp +867 -0
- data/ext/faiss/pca_matrix.cpp +13 -14
- data/ext/faiss/product_quantizer.cpp +23 -24
- data/ext/faiss/utils.cpp +10 -37
- data/ext/faiss/utils.h +2 -13
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +0 -5
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +26 -12
- data/lib/faiss/index.rb +0 -20
- data/lib/faiss/index_binary.rb +0 -20
- data/lib/faiss/kmeans.rb +0 -15
- data/lib/faiss/pca_matrix.rb +0 -15
- data/lib/faiss/product_quantizer.rb +0 -22
@@ -5,17 +5,14 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
|
9
8
|
#pragma once
|
10
9
|
|
11
10
|
#include <faiss/IndexPQ.h>
|
12
11
|
#include <faiss/impl/ProductQuantizer.h>
|
13
12
|
#include <faiss/utils/AlignedTable.h>
|
14
13
|
|
15
|
-
|
16
14
|
namespace faiss {
|
17
15
|
|
18
|
-
|
19
16
|
/** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
|
20
17
|
*
|
21
18
|
* The codes are not stored sequentially but grouped in blocks of size bbs.
|
@@ -28,7 +25,7 @@ namespace faiss {
|
|
28
25
|
* 15: no qbs with reservoir accumulator
|
29
26
|
*/
|
30
27
|
|
31
|
-
struct IndexPQFastScan: Index
|
28
|
+
struct IndexPQFastScan : Index {
|
32
29
|
ProductQuantizer pq;
|
33
30
|
|
34
31
|
// implementation to select
|
@@ -37,8 +34,8 @@ struct IndexPQFastScan: Index {
|
|
37
34
|
int skip = 0;
|
38
35
|
|
39
36
|
// size of the kernel
|
40
|
-
int bbs;
|
41
|
-
int qbs = 0;
|
37
|
+
int bbs; // set at build time
|
38
|
+
int qbs = 0; // query block size 0 = use default
|
42
39
|
|
43
40
|
// packed version of the codes
|
44
41
|
size_t ntotal2;
|
@@ -47,22 +44,23 @@ struct IndexPQFastScan: Index {
|
|
47
44
|
AlignedTable<uint8_t> codes;
|
48
45
|
|
49
46
|
// this is for testing purposes only (set when initialized by IndexPQ)
|
50
|
-
const uint8_t
|
47
|
+
const uint8_t* orig_codes = nullptr;
|
51
48
|
|
52
49
|
IndexPQFastScan(
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
50
|
+
int d,
|
51
|
+
size_t M,
|
52
|
+
size_t nbits,
|
53
|
+
MetricType metric = METRIC_L2,
|
54
|
+
int bbs = 32);
|
57
55
|
|
58
56
|
IndexPQFastScan();
|
59
57
|
|
60
58
|
/// build from an existing IndexPQ
|
61
|
-
explicit IndexPQFastScan(const IndexPQ
|
59
|
+
explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
|
62
60
|
|
63
|
-
void train
|
64
|
-
void add
|
65
|
-
void reset() override
|
61
|
+
void train(idx_t n, const float* x) override;
|
62
|
+
void add(idx_t n, const float* x) override;
|
63
|
+
void reset() override;
|
66
64
|
void search(
|
67
65
|
idx_t n,
|
68
66
|
const float* x,
|
@@ -72,35 +70,51 @@ struct IndexPQFastScan: Index {
|
|
72
70
|
|
73
71
|
// called by search function
|
74
72
|
void compute_quantized_LUT(
|
75
|
-
idx_t n,
|
76
|
-
|
73
|
+
idx_t n,
|
74
|
+
const float* x,
|
75
|
+
uint8_t* lut,
|
76
|
+
float* normalizers) const;
|
77
77
|
|
78
|
-
template<bool is_max>
|
78
|
+
template <bool is_max>
|
79
79
|
void search_dispatch_implem(
|
80
|
-
idx_t n,
|
81
|
-
float*
|
80
|
+
idx_t n,
|
81
|
+
const float* x,
|
82
|
+
idx_t k,
|
83
|
+
float* distances,
|
84
|
+
idx_t* labels) const;
|
82
85
|
|
83
|
-
template<class C>
|
86
|
+
template <class C>
|
84
87
|
void search_implem_2(
|
85
|
-
idx_t n,
|
86
|
-
float*
|
87
|
-
|
88
|
+
idx_t n,
|
89
|
+
const float* x,
|
90
|
+
idx_t k,
|
91
|
+
float* distances,
|
92
|
+
idx_t* labels) const;
|
88
93
|
|
89
|
-
template<class C>
|
94
|
+
template <class C>
|
90
95
|
void search_implem_12(
|
91
|
-
idx_t n,
|
92
|
-
float*
|
96
|
+
idx_t n,
|
97
|
+
const float* x,
|
98
|
+
idx_t k,
|
99
|
+
float* distances,
|
100
|
+
idx_t* labels,
|
101
|
+
int impl) const;
|
93
102
|
|
94
|
-
template<class C>
|
103
|
+
template <class C>
|
95
104
|
void search_implem_14(
|
96
|
-
idx_t n,
|
97
|
-
float*
|
98
|
-
|
105
|
+
idx_t n,
|
106
|
+
const float* x,
|
107
|
+
idx_t k,
|
108
|
+
float* distances,
|
109
|
+
idx_t* labels,
|
110
|
+
int impl) const;
|
99
111
|
};
|
100
112
|
|
101
113
|
struct FastScanStats {
|
102
114
|
uint64_t t0, t1, t2, t3;
|
103
|
-
FastScanStats() {
|
115
|
+
FastScanStats() {
|
116
|
+
reset();
|
117
|
+
}
|
104
118
|
void reset() {
|
105
119
|
memset(this, 0, sizeof(*this));
|
106
120
|
}
|
@@ -9,13 +9,13 @@
|
|
9
9
|
|
10
10
|
#include <faiss/IndexPreTransform.h>
|
11
11
|
|
12
|
-
#include <cstdio>
|
13
12
|
#include <cmath>
|
13
|
+
#include <cstdio>
|
14
14
|
#include <cstring>
|
15
15
|
#include <memory>
|
16
16
|
|
17
|
-
#include <faiss/impl/FaissAssert.h>
|
18
17
|
#include <faiss/impl/AuxIndexStructures.h>
|
18
|
+
#include <faiss/impl/FaissAssert.h>
|
19
19
|
|
20
20
|
namespace faiss {
|
21
21
|
|
@@ -23,44 +23,29 @@ namespace faiss {
|
|
23
23
|
* IndexPreTransform
|
24
24
|
*********************************************/
|
25
25
|
|
26
|
-
IndexPreTransform::IndexPreTransform ()
|
27
|
-
index(nullptr), own_fields (false)
|
28
|
-
{
|
29
|
-
}
|
30
|
-
|
26
|
+
IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {}
|
31
27
|
|
32
|
-
IndexPreTransform::IndexPreTransform
|
33
|
-
Index
|
34
|
-
Index (index->d, index->metric_type),
|
35
|
-
index (index), own_fields (false)
|
36
|
-
{
|
28
|
+
IndexPreTransform::IndexPreTransform(Index* index)
|
29
|
+
: Index(index->d, index->metric_type), index(index), own_fields(false) {
|
37
30
|
is_trained = index->is_trained;
|
38
31
|
ntotal = index->ntotal;
|
39
32
|
}
|
40
33
|
|
41
|
-
|
42
|
-
|
43
|
-
VectorTransform * ltrans,
|
44
|
-
Index * index):
|
45
|
-
Index (index->d, index->metric_type),
|
46
|
-
index (index), own_fields (false)
|
47
|
-
{
|
34
|
+
IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
|
35
|
+
: Index(index->d, index->metric_type), index(index), own_fields(false) {
|
48
36
|
is_trained = index->is_trained;
|
49
37
|
ntotal = index->ntotal;
|
50
|
-
prepend_transform
|
38
|
+
prepend_transform(ltrans);
|
51
39
|
}
|
52
40
|
|
53
|
-
void IndexPreTransform::prepend_transform
|
54
|
-
|
55
|
-
FAISS_THROW_IF_NOT (ltrans->d_out == d);
|
41
|
+
void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
|
42
|
+
FAISS_THROW_IF_NOT(ltrans->d_out == d);
|
56
43
|
is_trained = is_trained && ltrans->is_trained;
|
57
|
-
chain.insert
|
44
|
+
chain.insert(chain.begin(), ltrans);
|
58
45
|
d = ltrans->d_in;
|
59
46
|
}
|
60
47
|
|
61
|
-
|
62
|
-
IndexPreTransform::~IndexPreTransform ()
|
63
|
-
{
|
48
|
+
IndexPreTransform::~IndexPreTransform() {
|
64
49
|
if (own_fields) {
|
65
50
|
for (int i = 0; i < chain.size(); i++)
|
66
51
|
delete chain[i];
|
@@ -68,11 +53,7 @@ IndexPreTransform::~IndexPreTransform ()
|
|
68
53
|
}
|
69
54
|
}
|
70
55
|
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
void IndexPreTransform::train (idx_t n, const float *x)
|
75
|
-
{
|
56
|
+
void IndexPreTransform::train(idx_t n, const float* x) {
|
76
57
|
int last_untrained = 0;
|
77
58
|
if (!index->is_trained) {
|
78
59
|
last_untrained = chain.size();
|
@@ -84,7 +65,7 @@ void IndexPreTransform::train (idx_t n, const float *x)
|
|
84
65
|
}
|
85
66
|
}
|
86
67
|
}
|
87
|
-
const float
|
68
|
+
const float* prev_x = x;
|
88
69
|
ScopeDeleter<float> del;
|
89
70
|
|
90
71
|
if (verbose) {
|
@@ -93,34 +74,35 @@ void IndexPreTransform::train (idx_t n, const float *x)
|
|
93
74
|
}
|
94
75
|
|
95
76
|
for (int i = 0; i <= last_untrained; i++) {
|
96
|
-
|
97
77
|
if (i < chain.size()) {
|
98
|
-
VectorTransform
|
78
|
+
VectorTransform* ltrans = chain[i];
|
99
79
|
if (!ltrans->is_trained) {
|
100
80
|
if (verbose) {
|
101
81
|
printf(" Training chain component %d/%zd\n",
|
102
|
-
i,
|
103
|
-
|
82
|
+
i,
|
83
|
+
chain.size());
|
84
|
+
if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
|
104
85
|
opqm->verbose = true;
|
105
86
|
}
|
106
87
|
}
|
107
|
-
ltrans->train
|
88
|
+
ltrans->train(n, prev_x);
|
108
89
|
}
|
109
90
|
} else {
|
110
91
|
if (verbose) {
|
111
92
|
printf(" Training sub-index\n");
|
112
93
|
}
|
113
|
-
index->train
|
94
|
+
index->train(n, prev_x);
|
114
95
|
}
|
115
|
-
if (i == last_untrained)
|
96
|
+
if (i == last_untrained)
|
97
|
+
break;
|
116
98
|
if (verbose) {
|
117
|
-
printf(" Applying transform %d/%zd\n",
|
118
|
-
i, chain.size());
|
99
|
+
printf(" Applying transform %d/%zd\n", i, chain.size());
|
119
100
|
}
|
120
101
|
|
121
|
-
float
|
102
|
+
float* xt = chain[i]->apply(n, prev_x);
|
122
103
|
|
123
|
-
if (prev_x != x)
|
104
|
+
if (prev_x != x)
|
105
|
+
delete[] prev_x;
|
124
106
|
prev_x = xt;
|
125
107
|
del.set(xt);
|
126
108
|
}
|
@@ -128,200 +110,190 @@ void IndexPreTransform::train (idx_t n, const float *x)
|
|
128
110
|
is_trained = true;
|
129
111
|
}
|
130
112
|
|
131
|
-
|
132
|
-
const float
|
133
|
-
{
|
134
|
-
const float *prev_x = x;
|
113
|
+
const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
|
114
|
+
const float* prev_x = x;
|
135
115
|
ScopeDeleter<float> del;
|
136
116
|
|
137
117
|
for (int i = 0; i < chain.size(); i++) {
|
138
|
-
float
|
139
|
-
ScopeDeleter<float> del2
|
140
|
-
del2.swap
|
118
|
+
float* xt = chain[i]->apply(n, prev_x);
|
119
|
+
ScopeDeleter<float> del2(xt);
|
120
|
+
del2.swap(del);
|
141
121
|
prev_x = xt;
|
142
122
|
}
|
143
|
-
del.release
|
123
|
+
del.release();
|
144
124
|
return prev_x;
|
145
125
|
}
|
146
126
|
|
147
|
-
void IndexPreTransform::reverse_chain
|
148
|
-
{
|
127
|
+
void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
|
128
|
+
const {
|
149
129
|
const float* next_x = xt;
|
150
130
|
ScopeDeleter<float> del;
|
151
131
|
|
152
132
|
for (int i = chain.size() - 1; i >= 0; i--) {
|
153
|
-
float* prev_x = (i == 0) ? x : new float
|
154
|
-
ScopeDeleter<float> del2
|
155
|
-
chain
|
156
|
-
del2.swap
|
133
|
+
float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
|
134
|
+
ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
|
135
|
+
chain[i]->reverse_transform(n, next_x, prev_x);
|
136
|
+
del2.swap(del);
|
157
137
|
next_x = prev_x;
|
158
138
|
}
|
159
139
|
}
|
160
140
|
|
161
|
-
void IndexPreTransform::add
|
162
|
-
|
163
|
-
|
164
|
-
const float *xt = apply_chain (n, x);
|
141
|
+
void IndexPreTransform::add(idx_t n, const float* x) {
|
142
|
+
FAISS_THROW_IF_NOT(is_trained);
|
143
|
+
const float* xt = apply_chain(n, x);
|
165
144
|
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
166
|
-
index->add
|
145
|
+
index->add(n, xt);
|
167
146
|
ntotal = index->ntotal;
|
168
147
|
}
|
169
148
|
|
170
|
-
void IndexPreTransform::add_with_ids
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
149
|
+
void IndexPreTransform::add_with_ids(
|
150
|
+
idx_t n,
|
151
|
+
const float* x,
|
152
|
+
const idx_t* xids) {
|
153
|
+
FAISS_THROW_IF_NOT(is_trained);
|
154
|
+
const float* xt = apply_chain(n, x);
|
175
155
|
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
176
|
-
index->add_with_ids
|
156
|
+
index->add_with_ids(n, xt, xids);
|
177
157
|
ntotal = index->ntotal;
|
178
158
|
}
|
179
159
|
|
160
|
+
void IndexPreTransform::search(
|
161
|
+
idx_t n,
|
162
|
+
const float* x,
|
163
|
+
idx_t k,
|
164
|
+
float* distances,
|
165
|
+
idx_t* labels) const {
|
166
|
+
FAISS_THROW_IF_NOT(k > 0);
|
180
167
|
|
181
|
-
|
182
|
-
|
183
|
-
void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
|
184
|
-
float *distances, idx_t *labels) const
|
185
|
-
{
|
186
|
-
FAISS_THROW_IF_NOT (is_trained);
|
187
|
-
const float *xt = apply_chain (n, x);
|
168
|
+
FAISS_THROW_IF_NOT(is_trained);
|
169
|
+
const float* xt = apply_chain(n, x);
|
188
170
|
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
189
|
-
index->search
|
171
|
+
index->search(n, xt, k, distances, labels);
|
190
172
|
}
|
191
173
|
|
192
|
-
void IndexPreTransform::range_search
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
174
|
+
void IndexPreTransform::range_search(
|
175
|
+
idx_t n,
|
176
|
+
const float* x,
|
177
|
+
float radius,
|
178
|
+
RangeSearchResult* result) const {
|
179
|
+
FAISS_THROW_IF_NOT(is_trained);
|
180
|
+
const float* xt = apply_chain(n, x);
|
197
181
|
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
198
|
-
index->range_search
|
182
|
+
index->range_search(n, xt, radius, result);
|
199
183
|
}
|
200
184
|
|
201
|
-
|
202
|
-
|
203
|
-
void IndexPreTransform::reset () {
|
185
|
+
void IndexPreTransform::reset() {
|
204
186
|
index->reset();
|
205
187
|
ntotal = 0;
|
206
188
|
}
|
207
189
|
|
208
|
-
size_t IndexPreTransform::remove_ids
|
209
|
-
size_t nremove = index->remove_ids
|
190
|
+
size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
|
191
|
+
size_t nremove = index->remove_ids(sel);
|
210
192
|
ntotal = index->ntotal;
|
211
193
|
return nremove;
|
212
194
|
}
|
213
195
|
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
float *x = chain.empty() ? recons : new float [index->d];
|
218
|
-
ScopeDeleter<float> del (recons == x ? nullptr : x);
|
196
|
+
void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
|
197
|
+
float* x = chain.empty() ? recons : new float[index->d];
|
198
|
+
ScopeDeleter<float> del(recons == x ? nullptr : x);
|
219
199
|
// Initial reconstruction
|
220
|
-
index->reconstruct
|
200
|
+
index->reconstruct(key, x);
|
221
201
|
|
222
202
|
// Revert transformations from last to first
|
223
|
-
reverse_chain
|
203
|
+
reverse_chain(1, x, recons);
|
224
204
|
}
|
225
205
|
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
float *x = chain.empty() ? recons : new float [ni * index->d];
|
230
|
-
ScopeDeleter<float> del (recons == x ? nullptr : x);
|
206
|
+
void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
|
207
|
+
float* x = chain.empty() ? recons : new float[ni * index->d];
|
208
|
+
ScopeDeleter<float> del(recons == x ? nullptr : x);
|
231
209
|
// Initial reconstruction
|
232
|
-
index->reconstruct_n
|
210
|
+
index->reconstruct_n(i0, ni, x);
|
233
211
|
|
234
212
|
// Revert transformations from last to first
|
235
|
-
reverse_chain
|
213
|
+
reverse_chain(ni, x, recons);
|
236
214
|
}
|
237
215
|
|
216
|
+
void IndexPreTransform::search_and_reconstruct(
|
217
|
+
idx_t n,
|
218
|
+
const float* x,
|
219
|
+
idx_t k,
|
220
|
+
float* distances,
|
221
|
+
idx_t* labels,
|
222
|
+
float* recons) const {
|
223
|
+
FAISS_THROW_IF_NOT(k > 0);
|
238
224
|
|
239
|
-
|
240
|
-
idx_t n, const float *x, idx_t k,
|
241
|
-
float *distances, idx_t *labels, float* recons) const
|
242
|
-
{
|
243
|
-
FAISS_THROW_IF_NOT (is_trained);
|
225
|
+
FAISS_THROW_IF_NOT(is_trained);
|
244
226
|
|
245
|
-
const float* xt = apply_chain
|
246
|
-
ScopeDeleter<float> del
|
227
|
+
const float* xt = apply_chain(n, x);
|
228
|
+
ScopeDeleter<float> del((xt == x) ? nullptr : xt);
|
247
229
|
|
248
|
-
float* recons_temp = chain.empty() ? recons : new float
|
249
|
-
ScopeDeleter<float> del2
|
250
|
-
index->search_and_reconstruct
|
230
|
+
float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
|
231
|
+
ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
|
232
|
+
index->search_and_reconstruct(n, xt, k, distances, labels, recons_temp);
|
251
233
|
|
252
234
|
// Revert transformations from last to first
|
253
|
-
reverse_chain
|
235
|
+
reverse_chain(n * k, recons_temp, recons);
|
254
236
|
}
|
255
237
|
|
256
|
-
size_t IndexPreTransform::sa_code_size
|
257
|
-
|
258
|
-
return index->sa_code_size ();
|
238
|
+
size_t IndexPreTransform::sa_code_size() const {
|
239
|
+
return index->sa_code_size();
|
259
240
|
}
|
260
241
|
|
261
|
-
void IndexPreTransform::sa_encode
|
262
|
-
|
263
|
-
{
|
242
|
+
void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
243
|
+
const {
|
264
244
|
if (chain.empty()) {
|
265
|
-
index->sa_encode
|
245
|
+
index->sa_encode(n, x, bytes);
|
266
246
|
} else {
|
267
|
-
const float
|
247
|
+
const float* xt = apply_chain(n, x);
|
268
248
|
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
269
|
-
index->sa_encode
|
249
|
+
index->sa_encode(n, xt, bytes);
|
270
250
|
}
|
271
251
|
}
|
272
252
|
|
273
|
-
void IndexPreTransform::sa_decode
|
274
|
-
|
275
|
-
{
|
253
|
+
void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
254
|
+
const {
|
276
255
|
if (chain.empty()) {
|
277
|
-
index->sa_decode
|
256
|
+
index->sa_decode(n, bytes, x);
|
278
257
|
} else {
|
279
|
-
std::unique_ptr<float
|
280
|
-
index->sa_decode
|
258
|
+
std::unique_ptr<float[]> x1(new float[index->d * n]);
|
259
|
+
index->sa_decode(n, bytes, x1.get());
|
281
260
|
// Revert transformations from last to first
|
282
|
-
reverse_chain
|
261
|
+
reverse_chain(n, x1.get(), x);
|
283
262
|
}
|
284
263
|
}
|
285
264
|
|
286
265
|
namespace {
|
287
266
|
|
288
|
-
struct PreTransformDistanceComputer: DistanceComputer {
|
289
|
-
const IndexPreTransform
|
267
|
+
struct PreTransformDistanceComputer : DistanceComputer {
|
268
|
+
const IndexPreTransform* index;
|
290
269
|
std::unique_ptr<DistanceComputer> sub_dc;
|
291
|
-
std::unique_ptr<const float
|
270
|
+
std::unique_ptr<const float[]> query;
|
292
271
|
|
293
|
-
explicit PreTransformDistanceComputer(const IndexPreTransform
|
294
|
-
|
295
|
-
sub_dc(index->index->get_distance_computer())
|
296
|
-
{}
|
272
|
+
explicit PreTransformDistanceComputer(const IndexPreTransform* index)
|
273
|
+
: index(index), sub_dc(index->index->get_distance_computer()) {}
|
297
274
|
|
298
|
-
void set_query(const float
|
299
|
-
const float
|
275
|
+
void set_query(const float* x) override {
|
276
|
+
const float* xt = index->apply_chain(1, x);
|
300
277
|
if (xt == x) {
|
301
|
-
sub_dc->set_query
|
278
|
+
sub_dc->set_query(x);
|
302
279
|
} else {
|
303
280
|
query.reset(xt);
|
304
|
-
sub_dc->set_query
|
281
|
+
sub_dc->set_query(xt);
|
305
282
|
}
|
306
283
|
}
|
307
284
|
|
308
|
-
float symmetric_dis(idx_t i, idx_t j) override
|
309
|
-
{
|
285
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
310
286
|
return sub_dc->symmetric_dis(i, j);
|
311
287
|
}
|
312
288
|
|
313
|
-
float operator
|
314
|
-
{
|
289
|
+
float operator()(idx_t i) override {
|
315
290
|
return (*sub_dc)(i);
|
316
291
|
}
|
317
|
-
|
318
292
|
};
|
319
293
|
|
320
|
-
|
321
294
|
} // anonymous namespace
|
322
295
|
|
323
|
-
|
324
|
-
DistanceComputer * IndexPreTransform::get_distance_computer() const {
|
296
|
+
DistanceComputer* IndexPreTransform::get_distance_computer() const {
|
325
297
|
if (chain.empty()) {
|
326
298
|
return index->get_distance_computer();
|
327
299
|
} else {
|
@@ -329,6 +301,4 @@ DistanceComputer * IndexPreTransform::get_distance_computer() const {
|
|
329
301
|
}
|
330
302
|
}
|
331
303
|
|
332
|
-
|
333
|
-
|
334
304
|
} // namespace faiss
|