faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -67,7 +67,7 @@ void IndexPreTransform::train(idx_t n, const float* x) {
|
|
|
67
67
|
}
|
|
68
68
|
}
|
|
69
69
|
const float* prev_x = x;
|
|
70
|
-
|
|
70
|
+
std::unique_ptr<const float[]> del;
|
|
71
71
|
|
|
72
72
|
if (verbose) {
|
|
73
73
|
printf("IndexPreTransform::train: training chain 0 to %d\n",
|
|
@@ -102,10 +102,12 @@ void IndexPreTransform::train(idx_t n, const float* x) {
|
|
|
102
102
|
|
|
103
103
|
float* xt = chain[i]->apply(n, prev_x);
|
|
104
104
|
|
|
105
|
-
if (prev_x != x)
|
|
106
|
-
|
|
105
|
+
if (prev_x != x) {
|
|
106
|
+
del.reset();
|
|
107
|
+
}
|
|
108
|
+
|
|
107
109
|
prev_x = xt;
|
|
108
|
-
del.
|
|
110
|
+
del.reset(xt);
|
|
109
111
|
}
|
|
110
112
|
|
|
111
113
|
is_trained = true;
|
|
@@ -113,11 +115,11 @@ void IndexPreTransform::train(idx_t n, const float* x) {
|
|
|
113
115
|
|
|
114
116
|
const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
|
|
115
117
|
const float* prev_x = x;
|
|
116
|
-
|
|
118
|
+
std::unique_ptr<const float[]> del;
|
|
117
119
|
|
|
118
120
|
for (int i = 0; i < chain.size(); i++) {
|
|
119
121
|
float* xt = chain[i]->apply(n, prev_x);
|
|
120
|
-
|
|
122
|
+
std::unique_ptr<const float[]> del2(xt);
|
|
121
123
|
del2.swap(del);
|
|
122
124
|
prev_x = xt;
|
|
123
125
|
}
|
|
@@ -128,11 +130,11 @@ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
|
|
|
128
130
|
void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
|
|
129
131
|
const {
|
|
130
132
|
const float* next_x = xt;
|
|
131
|
-
|
|
133
|
+
std::unique_ptr<const float[]> del;
|
|
132
134
|
|
|
133
135
|
for (int i = chain.size() - 1; i >= 0; i--) {
|
|
134
136
|
float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
|
|
135
|
-
|
|
137
|
+
std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x);
|
|
136
138
|
chain[i]->reverse_transform(n, next_x, prev_x);
|
|
137
139
|
del2.swap(del);
|
|
138
140
|
next_x = prev_x;
|
|
@@ -141,9 +143,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
|
|
|
141
143
|
|
|
142
144
|
void IndexPreTransform::add(idx_t n, const float* x) {
|
|
143
145
|
FAISS_THROW_IF_NOT(is_trained);
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
index->add(n, xt);
|
|
146
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
|
147
|
+
index->add(n, tv.x);
|
|
147
148
|
ntotal = index->ntotal;
|
|
148
149
|
}
|
|
149
150
|
|
|
@@ -152,9 +153,8 @@ void IndexPreTransform::add_with_ids(
|
|
|
152
153
|
const float* x,
|
|
153
154
|
const idx_t* xids) {
|
|
154
155
|
FAISS_THROW_IF_NOT(is_trained);
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
index->add_with_ids(n, xt, xids);
|
|
156
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
|
157
|
+
index->add_with_ids(n, tv.x, xids);
|
|
158
158
|
ntotal = index->ntotal;
|
|
159
159
|
}
|
|
160
160
|
|
|
@@ -178,7 +178,7 @@ void IndexPreTransform::search(
|
|
|
178
178
|
FAISS_THROW_IF_NOT(k > 0);
|
|
179
179
|
FAISS_THROW_IF_NOT(is_trained);
|
|
180
180
|
const float* xt = apply_chain(n, x);
|
|
181
|
-
|
|
181
|
+
std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
|
|
182
182
|
index->search(
|
|
183
183
|
n, xt, k, distances, labels, extract_index_search_params(params));
|
|
184
184
|
}
|
|
@@ -190,10 +190,9 @@ void IndexPreTransform::range_search(
|
|
|
190
190
|
RangeSearchResult* result,
|
|
191
191
|
const SearchParameters* params) const {
|
|
192
192
|
FAISS_THROW_IF_NOT(is_trained);
|
|
193
|
-
|
|
194
|
-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
|
193
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
|
195
194
|
index->range_search(
|
|
196
|
-
n,
|
|
195
|
+
n, tv.x, radius, result, extract_index_search_params(params));
|
|
197
196
|
}
|
|
198
197
|
|
|
199
198
|
void IndexPreTransform::reset() {
|
|
@@ -209,7 +208,7 @@ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
|
|
|
209
208
|
|
|
210
209
|
void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
|
|
211
210
|
float* x = chain.empty() ? recons : new float[index->d];
|
|
212
|
-
|
|
211
|
+
std::unique_ptr<float[]> del(recons == x ? nullptr : x);
|
|
213
212
|
// Initial reconstruction
|
|
214
213
|
index->reconstruct(key, x);
|
|
215
214
|
|
|
@@ -219,7 +218,7 @@ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
|
|
|
219
218
|
|
|
220
219
|
void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
|
|
221
220
|
float* x = chain.empty() ? recons : new float[ni * index->d];
|
|
222
|
-
|
|
221
|
+
std::unique_ptr<float[]> del(recons == x ? nullptr : x);
|
|
223
222
|
// Initial reconstruction
|
|
224
223
|
index->reconstruct_n(i0, ni, x);
|
|
225
224
|
|
|
@@ -238,14 +237,14 @@ void IndexPreTransform::search_and_reconstruct(
|
|
|
238
237
|
FAISS_THROW_IF_NOT(k > 0);
|
|
239
238
|
FAISS_THROW_IF_NOT(is_trained);
|
|
240
239
|
|
|
241
|
-
|
|
242
|
-
ScopeDeleter<float> del((xt == x) ? nullptr : xt);
|
|
240
|
+
TransformedVectors trans(x, apply_chain(n, x));
|
|
243
241
|
|
|
244
242
|
float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
|
|
245
|
-
|
|
243
|
+
std::unique_ptr<float[]> del2(
|
|
244
|
+
(recons_temp == recons) ? nullptr : recons_temp);
|
|
246
245
|
index->search_and_reconstruct(
|
|
247
246
|
n,
|
|
248
|
-
|
|
247
|
+
trans.x,
|
|
249
248
|
k,
|
|
250
249
|
distances,
|
|
251
250
|
labels,
|
|
@@ -262,13 +261,8 @@ size_t IndexPreTransform::sa_code_size() const {
|
|
|
262
261
|
|
|
263
262
|
void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
|
264
263
|
const {
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
} else {
|
|
268
|
-
const float* xt = apply_chain(n, x);
|
|
269
|
-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
|
270
|
-
index->sa_encode(n, xt, bytes);
|
|
271
|
-
}
|
|
264
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
|
265
|
+
index->sa_encode(n, tv.x, bytes);
|
|
272
266
|
}
|
|
273
267
|
|
|
274
268
|
void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
@@ -23,7 +23,7 @@ struct SearchParametersPreTransform : SearchParameters {
|
|
|
23
23
|
/** Index that applies a LinearTransform transform on vectors before
|
|
24
24
|
* handing them over to a sub-index */
|
|
25
25
|
struct IndexPreTransform : Index {
|
|
26
|
-
std::vector<VectorTransform*> chain; ///! chain of
|
|
26
|
+
std::vector<VectorTransform*> chain; ///! chain of transforms
|
|
27
27
|
Index* index; ///! the sub-index
|
|
28
28
|
|
|
29
29
|
bool own_fields; ///! whether pointers are deleted in destructor
|
|
@@ -62,18 +62,18 @@ void IndexRefine::reset() {
|
|
|
62
62
|
|
|
63
63
|
namespace {
|
|
64
64
|
|
|
65
|
-
|
|
65
|
+
using idx_t = faiss::idx_t;
|
|
66
66
|
|
|
67
67
|
template <class C>
|
|
68
68
|
static void reorder_2_heaps(
|
|
69
69
|
idx_t n,
|
|
70
70
|
idx_t k,
|
|
71
|
-
idx_t* labels,
|
|
72
|
-
float* distances,
|
|
71
|
+
idx_t* __restrict labels,
|
|
72
|
+
float* __restrict distances,
|
|
73
73
|
idx_t k_base,
|
|
74
|
-
const idx_t* base_labels,
|
|
75
|
-
const float* base_distances) {
|
|
76
|
-
#pragma omp parallel for
|
|
74
|
+
const idx_t* __restrict base_labels,
|
|
75
|
+
const float* __restrict base_distances) {
|
|
76
|
+
#pragma omp parallel for if (n > 1)
|
|
77
77
|
for (idx_t i = 0; i < n; i++) {
|
|
78
78
|
idx_t* idxo = labels + i * k;
|
|
79
79
|
float* diso = distances + i * k;
|
|
@@ -96,25 +96,40 @@ void IndexRefine::search(
|
|
|
96
96
|
idx_t k,
|
|
97
97
|
float* distances,
|
|
98
98
|
idx_t* labels,
|
|
99
|
-
const SearchParameters*
|
|
100
|
-
|
|
101
|
-
|
|
99
|
+
const SearchParameters* params_in) const {
|
|
100
|
+
const IndexRefineSearchParameters* params = nullptr;
|
|
101
|
+
if (params_in) {
|
|
102
|
+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
|
|
103
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
104
|
+
params, "IndexRefine params have incorrect type");
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
|
|
108
|
+
: idx_t(k * k_factor);
|
|
109
|
+
SearchParameters* base_index_params =
|
|
110
|
+
(params != nullptr) ? params->base_index_params : nullptr;
|
|
111
|
+
|
|
112
|
+
FAISS_THROW_IF_NOT(k_base >= k);
|
|
113
|
+
|
|
114
|
+
FAISS_THROW_IF_NOT(base_index);
|
|
115
|
+
FAISS_THROW_IF_NOT(refine_index);
|
|
116
|
+
|
|
102
117
|
FAISS_THROW_IF_NOT(k > 0);
|
|
103
118
|
FAISS_THROW_IF_NOT(is_trained);
|
|
104
|
-
idx_t k_base = idx_t(k * k_factor);
|
|
105
119
|
idx_t* base_labels = labels;
|
|
106
120
|
float* base_distances = distances;
|
|
107
|
-
|
|
108
|
-
|
|
121
|
+
std::unique_ptr<idx_t[]> del1;
|
|
122
|
+
std::unique_ptr<float[]> del2;
|
|
109
123
|
|
|
110
124
|
if (k != k_base) {
|
|
111
125
|
base_labels = new idx_t[n * k_base];
|
|
112
|
-
del1.
|
|
126
|
+
del1.reset(base_labels);
|
|
113
127
|
base_distances = new float[n * k_base];
|
|
114
|
-
del2.
|
|
128
|
+
del2.reset(base_distances);
|
|
115
129
|
}
|
|
116
130
|
|
|
117
|
-
base_index->search(
|
|
131
|
+
base_index->search(
|
|
132
|
+
n, x, k_base, base_distances, base_labels, base_index_params);
|
|
118
133
|
|
|
119
134
|
for (int i = 0; i < n * k_base; i++)
|
|
120
135
|
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
@@ -225,25 +240,40 @@ void IndexRefineFlat::search(
|
|
|
225
240
|
idx_t k,
|
|
226
241
|
float* distances,
|
|
227
242
|
idx_t* labels,
|
|
228
|
-
const SearchParameters*
|
|
229
|
-
|
|
230
|
-
|
|
243
|
+
const SearchParameters* params_in) const {
|
|
244
|
+
const IndexRefineSearchParameters* params = nullptr;
|
|
245
|
+
if (params_in) {
|
|
246
|
+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
|
|
247
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
248
|
+
params, "IndexRefineFlat params have incorrect type");
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
|
|
252
|
+
: idx_t(k * k_factor);
|
|
253
|
+
SearchParameters* base_index_params =
|
|
254
|
+
(params != nullptr) ? params->base_index_params : nullptr;
|
|
255
|
+
|
|
256
|
+
FAISS_THROW_IF_NOT(k_base >= k);
|
|
257
|
+
|
|
258
|
+
FAISS_THROW_IF_NOT(base_index);
|
|
259
|
+
FAISS_THROW_IF_NOT(refine_index);
|
|
260
|
+
|
|
231
261
|
FAISS_THROW_IF_NOT(k > 0);
|
|
232
262
|
FAISS_THROW_IF_NOT(is_trained);
|
|
233
|
-
idx_t k_base = idx_t(k * k_factor);
|
|
234
263
|
idx_t* base_labels = labels;
|
|
235
264
|
float* base_distances = distances;
|
|
236
|
-
|
|
237
|
-
|
|
265
|
+
std::unique_ptr<idx_t[]> del1;
|
|
266
|
+
std::unique_ptr<float[]> del2;
|
|
238
267
|
|
|
239
268
|
if (k != k_base) {
|
|
240
269
|
base_labels = new idx_t[n * k_base];
|
|
241
|
-
del1.
|
|
270
|
+
del1.reset(base_labels);
|
|
242
271
|
base_distances = new float[n * k_base];
|
|
243
|
-
del2.
|
|
272
|
+
del2.reset(base_distances);
|
|
244
273
|
}
|
|
245
274
|
|
|
246
|
-
base_index->search(
|
|
275
|
+
base_index->search(
|
|
276
|
+
n, x, k_base, base_distances, base_labels, base_index_params);
|
|
247
277
|
|
|
248
278
|
for (int i = 0; i < n * k_base; i++)
|
|
249
279
|
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
@@ -11,6 +11,13 @@
|
|
|
11
11
|
|
|
12
12
|
namespace faiss {
|
|
13
13
|
|
|
14
|
+
struct IndexRefineSearchParameters : SearchParameters {
|
|
15
|
+
float k_factor = 1;
|
|
16
|
+
SearchParameters* base_index_params = nullptr; // non-owning
|
|
17
|
+
|
|
18
|
+
virtual ~IndexRefineSearchParameters() = default;
|
|
19
|
+
};
|
|
20
|
+
|
|
14
21
|
/** Index that queries in a base_index (a fast one) and refines the
|
|
15
22
|
* results with an exact search, hopefully improving the results.
|
|
16
23
|
*/
|
|
@@ -12,17 +12,34 @@
|
|
|
12
12
|
|
|
13
13
|
namespace faiss {
|
|
14
14
|
|
|
15
|
+
namespace {
|
|
16
|
+
|
|
17
|
+
// IndexBinary needs to update the code_size when d is set...
|
|
18
|
+
|
|
19
|
+
void sync_d(Index* index) {}
|
|
20
|
+
|
|
21
|
+
void sync_d(IndexBinary* index) {
|
|
22
|
+
FAISS_THROW_IF_NOT(index->d % 8 == 0);
|
|
23
|
+
index->code_size = index->d / 8;
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
} // anonymous namespace
|
|
27
|
+
|
|
15
28
|
template <typename IndexT>
|
|
16
29
|
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
|
|
17
30
|
: ThreadedIndex<IndexT>(threaded) {}
|
|
18
31
|
|
|
19
32
|
template <typename IndexT>
|
|
20
33
|
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d, bool threaded)
|
|
21
|
-
: ThreadedIndex<IndexT>(d, threaded) {
|
|
34
|
+
: ThreadedIndex<IndexT>(d, threaded) {
|
|
35
|
+
sync_d(this);
|
|
36
|
+
}
|
|
22
37
|
|
|
23
38
|
template <typename IndexT>
|
|
24
39
|
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d, bool threaded)
|
|
25
|
-
: ThreadedIndex<IndexT>(d, threaded) {
|
|
40
|
+
: ThreadedIndex<IndexT>(d, threaded) {
|
|
41
|
+
sync_d(this);
|
|
42
|
+
}
|
|
26
43
|
|
|
27
44
|
template <typename IndexT>
|
|
28
45
|
void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
|
|
@@ -168,6 +185,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
|
|
|
168
185
|
}
|
|
169
186
|
|
|
170
187
|
auto firstIndex = this->at(0);
|
|
188
|
+
this->d = firstIndex->d;
|
|
189
|
+
sync_d(this);
|
|
171
190
|
this->metric_type = firstIndex->metric_type;
|
|
172
191
|
this->is_trained = firstIndex->is_trained;
|
|
173
192
|
this->ntotal = firstIndex->ntotal;
|
|
@@ -181,30 +200,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
|
|
|
181
200
|
}
|
|
182
201
|
}
|
|
183
202
|
|
|
184
|
-
// No metric_type for IndexBinary
|
|
185
|
-
template <>
|
|
186
|
-
void IndexReplicasTemplate<IndexBinary>::syncWithSubIndexes() {
|
|
187
|
-
if (!this->count()) {
|
|
188
|
-
this->is_trained = false;
|
|
189
|
-
this->ntotal = 0;
|
|
190
|
-
|
|
191
|
-
return;
|
|
192
|
-
}
|
|
193
|
-
|
|
194
|
-
auto firstIndex = this->at(0);
|
|
195
|
-
this->is_trained = firstIndex->is_trained;
|
|
196
|
-
this->ntotal = firstIndex->ntotal;
|
|
197
|
-
|
|
198
|
-
for (int i = 1; i < this->count(); ++i) {
|
|
199
|
-
auto index = this->at(i);
|
|
200
|
-
FAISS_THROW_IF_NOT(this->d == index->d);
|
|
201
|
-
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
|
|
202
|
-
FAISS_THROW_IF_NOT(this->ntotal == index->ntotal);
|
|
203
|
-
}
|
|
204
|
-
}
|
|
205
|
-
|
|
206
203
|
// explicit instantiations
|
|
207
|
-
template
|
|
208
|
-
template
|
|
204
|
+
template class IndexReplicasTemplate<Index>;
|
|
205
|
+
template class IndexReplicasTemplate<IndexBinary>;
|
|
209
206
|
|
|
210
207
|
} // namespace faiss
|
|
@@ -32,7 +32,9 @@ IndexScalarQuantizer::IndexScalarQuantizer(
|
|
|
32
32
|
MetricType metric)
|
|
33
33
|
: IndexFlatCodes(0, d, metric), sq(d, qtype) {
|
|
34
34
|
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
|
|
35
|
-
qtype == ScalarQuantizer::QT_8bit_direct
|
|
35
|
+
qtype == ScalarQuantizer::QT_8bit_direct ||
|
|
36
|
+
qtype == ScalarQuantizer::QT_bf16 ||
|
|
37
|
+
qtype == ScalarQuantizer::QT_8bit_direct_signed;
|
|
36
38
|
code_size = sq.code_size;
|
|
37
39
|
}
|
|
38
40
|
|
|
@@ -60,10 +62,9 @@ void IndexScalarQuantizer::search(
|
|
|
60
62
|
|
|
61
63
|
#pragma omp parallel
|
|
62
64
|
{
|
|
63
|
-
InvertedListScanner
|
|
64
|
-
sq.select_InvertedListScanner(metric_type, nullptr, true, sel);
|
|
65
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
|
66
|
+
sq.select_InvertedListScanner(metric_type, nullptr, true, sel));
|
|
65
67
|
|
|
66
|
-
ScopeDeleter1<InvertedListScanner> del(scanner);
|
|
67
68
|
scanner->list_no = 0; // directly the list number
|
|
68
69
|
|
|
69
70
|
#pragma omp for
|
|
@@ -122,21 +123,28 @@ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
|
|
|
122
123
|
size_t nlist,
|
|
123
124
|
ScalarQuantizer::QuantizerType qtype,
|
|
124
125
|
MetricType metric,
|
|
125
|
-
bool
|
|
126
|
-
: IndexIVF(quantizer, d, nlist, 0, metric),
|
|
127
|
-
sq(d, qtype),
|
|
128
|
-
by_residual(encode_residual) {
|
|
126
|
+
bool by_residual)
|
|
127
|
+
: IndexIVF(quantizer, d, nlist, 0, metric), sq(d, qtype) {
|
|
129
128
|
code_size = sq.code_size;
|
|
129
|
+
this->by_residual = by_residual;
|
|
130
130
|
// was not known at construction time
|
|
131
131
|
invlists->code_size = code_size;
|
|
132
132
|
is_trained = false;
|
|
133
133
|
}
|
|
134
134
|
|
|
135
|
-
IndexIVFScalarQuantizer::IndexIVFScalarQuantizer()
|
|
136
|
-
|
|
135
|
+
IndexIVFScalarQuantizer::IndexIVFScalarQuantizer() : IndexIVF() {
|
|
136
|
+
by_residual = true;
|
|
137
|
+
}
|
|
137
138
|
|
|
138
|
-
void IndexIVFScalarQuantizer::
|
|
139
|
-
|
|
139
|
+
void IndexIVFScalarQuantizer::train_encoder(
|
|
140
|
+
idx_t n,
|
|
141
|
+
const float* x,
|
|
142
|
+
const idx_t* assign) {
|
|
143
|
+
sq.train(n, x);
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
idx_t IndexIVFScalarQuantizer::train_encoder_num_vectors() const {
|
|
147
|
+
return 100000;
|
|
140
148
|
}
|
|
141
149
|
|
|
142
150
|
void IndexIVFScalarQuantizer::encode_vectors(
|
|
@@ -201,15 +209,15 @@ void IndexIVFScalarQuantizer::add_core(
|
|
|
201
209
|
idx_t n,
|
|
202
210
|
const float* x,
|
|
203
211
|
const idx_t* xids,
|
|
204
|
-
const idx_t* coarse_idx
|
|
212
|
+
const idx_t* coarse_idx,
|
|
213
|
+
void* inverted_list_context) {
|
|
205
214
|
FAISS_THROW_IF_NOT(is_trained);
|
|
206
215
|
|
|
207
|
-
size_t nadd = 0;
|
|
208
216
|
std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
|
|
209
217
|
|
|
210
218
|
DirectMapAdd dm_add(direct_map, n, xids);
|
|
211
219
|
|
|
212
|
-
#pragma omp parallel
|
|
220
|
+
#pragma omp parallel
|
|
213
221
|
{
|
|
214
222
|
std::vector<float> residual(d);
|
|
215
223
|
std::vector<uint8_t> one_code(code_size);
|
|
@@ -231,10 +239,10 @@ void IndexIVFScalarQuantizer::add_core(
|
|
|
231
239
|
memset(one_code.data(), 0, code_size);
|
|
232
240
|
squant->encode_vector(xi, one_code.data());
|
|
233
241
|
|
|
234
|
-
size_t ofs = invlists->add_entry(
|
|
242
|
+
size_t ofs = invlists->add_entry(
|
|
243
|
+
list_no, id, one_code.data(), inverted_list_context);
|
|
235
244
|
|
|
236
245
|
dm_add.add(i, list_no, ofs);
|
|
237
|
-
nadd++;
|
|
238
246
|
|
|
239
247
|
} else if (rank == 0 && list_no == -1) {
|
|
240
248
|
dm_add.add(i, -1, 0);
|
|
@@ -65,7 +65,6 @@ struct IndexScalarQuantizer : IndexFlatCodes {
|
|
|
65
65
|
|
|
66
66
|
struct IndexIVFScalarQuantizer : IndexIVF {
|
|
67
67
|
ScalarQuantizer sq;
|
|
68
|
-
bool by_residual;
|
|
69
68
|
|
|
70
69
|
IndexIVFScalarQuantizer(
|
|
71
70
|
Index* quantizer,
|
|
@@ -73,11 +72,13 @@ struct IndexIVFScalarQuantizer : IndexIVF {
|
|
|
73
72
|
size_t nlist,
|
|
74
73
|
ScalarQuantizer::QuantizerType qtype,
|
|
75
74
|
MetricType metric = METRIC_L2,
|
|
76
|
-
bool
|
|
75
|
+
bool by_residual = true);
|
|
77
76
|
|
|
78
77
|
IndexIVFScalarQuantizer();
|
|
79
78
|
|
|
80
|
-
void
|
|
79
|
+
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
|
|
80
|
+
|
|
81
|
+
idx_t train_encoder_num_vectors() const override;
|
|
81
82
|
|
|
82
83
|
void encode_vectors(
|
|
83
84
|
idx_t n,
|
|
@@ -90,7 +91,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
|
|
|
90
91
|
idx_t n,
|
|
91
92
|
const float* x,
|
|
92
93
|
const idx_t* xids,
|
|
93
|
-
const idx_t* precomputed_idx
|
|
94
|
+
const idx_t* precomputed_idx,
|
|
95
|
+
void* inverted_list_context = nullptr) override;
|
|
94
96
|
|
|
95
97
|
InvertedListScanner* get_InvertedListScanner(
|
|
96
98
|
bool store_pairs,
|
|
@@ -5,8 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#include <faiss/IndexShards.h>
|
|
11
9
|
|
|
12
10
|
#include <cinttypes>
|
|
@@ -22,6 +20,15 @@ namespace faiss {
|
|
|
22
20
|
// subroutines
|
|
23
21
|
namespace {
|
|
24
22
|
|
|
23
|
+
// IndexBinary needs to update the code_size when d is set...
|
|
24
|
+
|
|
25
|
+
void sync_d(Index* index) {}
|
|
26
|
+
|
|
27
|
+
void sync_d(IndexBinary* index) {
|
|
28
|
+
FAISS_THROW_IF_NOT(index->d % 8 == 0);
|
|
29
|
+
index->code_size = index->d / 8;
|
|
30
|
+
}
|
|
31
|
+
|
|
25
32
|
// add translation to all valid labels
|
|
26
33
|
void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
|
|
27
34
|
if (translation == 0)
|
|
@@ -40,20 +47,26 @@ IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
|
40
47
|
idx_t d,
|
|
41
48
|
bool threaded,
|
|
42
49
|
bool successive_ids)
|
|
43
|
-
: ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
|
|
50
|
+
: ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {
|
|
51
|
+
sync_d(this);
|
|
52
|
+
}
|
|
44
53
|
|
|
45
54
|
template <typename IndexT>
|
|
46
55
|
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
47
56
|
int d,
|
|
48
57
|
bool threaded,
|
|
49
58
|
bool successive_ids)
|
|
50
|
-
: ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
|
|
59
|
+
: ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {
|
|
60
|
+
sync_d(this);
|
|
61
|
+
}
|
|
51
62
|
|
|
52
63
|
template <typename IndexT>
|
|
53
64
|
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
54
65
|
bool threaded,
|
|
55
66
|
bool successive_ids)
|
|
56
|
-
: ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
|
|
67
|
+
: ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {
|
|
68
|
+
sync_d(this);
|
|
69
|
+
}
|
|
57
70
|
|
|
58
71
|
template <typename IndexT>
|
|
59
72
|
void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
|
|
@@ -78,6 +91,8 @@ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
|
|
|
78
91
|
}
|
|
79
92
|
|
|
80
93
|
auto firstIndex = this->at(0);
|
|
94
|
+
this->d = firstIndex->d;
|
|
95
|
+
sync_d(this);
|
|
81
96
|
this->metric_type = firstIndex->metric_type;
|
|
82
97
|
this->is_trained = firstIndex->is_trained;
|
|
83
98
|
this->ntotal = firstIndex->ntotal;
|
|
@@ -92,29 +107,6 @@ void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
|
|
|
92
107
|
}
|
|
93
108
|
}
|
|
94
109
|
|
|
95
|
-
// No metric_type for IndexBinary
|
|
96
|
-
template <>
|
|
97
|
-
void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
|
|
98
|
-
if (!this->count()) {
|
|
99
|
-
this->is_trained = false;
|
|
100
|
-
this->ntotal = 0;
|
|
101
|
-
|
|
102
|
-
return;
|
|
103
|
-
}
|
|
104
|
-
|
|
105
|
-
auto firstIndex = this->at(0);
|
|
106
|
-
this->is_trained = firstIndex->is_trained;
|
|
107
|
-
this->ntotal = firstIndex->ntotal;
|
|
108
|
-
|
|
109
|
-
for (int i = 1; i < this->count(); ++i) {
|
|
110
|
-
auto index = this->at(i);
|
|
111
|
-
FAISS_THROW_IF_NOT(this->d == index->d);
|
|
112
|
-
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
|
|
113
|
-
|
|
114
|
-
this->ntotal += index->ntotal;
|
|
115
|
-
}
|
|
116
|
-
}
|
|
117
|
-
|
|
118
110
|
template <typename IndexT>
|
|
119
111
|
void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
|
|
120
112
|
auto fn = [n, x](int no, IndexT* index) {
|
|
@@ -155,7 +147,7 @@ void IndexShardsTemplate<IndexT>::add_with_ids(
|
|
|
155
147
|
"request them to be shifted");
|
|
156
148
|
FAISS_THROW_IF_NOT_MSG(
|
|
157
149
|
this->ntotal == 0,
|
|
158
|
-
"when adding to IndexShards with sucessive_ids, "
|
|
150
|
+
"when adding to IndexShards with successive_ids, "
|
|
159
151
|
"only add() in a single pass is supported");
|
|
160
152
|
}
|
|
161
153
|
|
|
@@ -111,7 +111,7 @@ void IndexShardsIVF::add_with_ids(
|
|
|
111
111
|
"request them to be shifted");
|
|
112
112
|
FAISS_THROW_IF_NOT_MSG(
|
|
113
113
|
this->ntotal == 0,
|
|
114
|
-
"when adding to IndexShards with sucessive_ids, "
|
|
114
|
+
"when adding to IndexShards with successive_ids, "
|
|
115
115
|
"only add() in a single pass is supported");
|
|
116
116
|
}
|
|
117
117
|
|
|
@@ -137,7 +137,6 @@ void IndexShardsIVF::add_with_ids(
|
|
|
137
137
|
auto fn = [n, ids, x, nshard, d, Iq](int no, Index* index) {
|
|
138
138
|
idx_t i0 = (idx_t)no * n / nshard;
|
|
139
139
|
idx_t i1 = ((idx_t)no + 1) * n / nshard;
|
|
140
|
-
const float* x0 = x + i0 * d;
|
|
141
140
|
auto index_ivf = dynamic_cast<IndexIVF*>(index);
|
|
142
141
|
|
|
143
142
|
if (index->verbose) {
|