faiss 0.3.1 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +35 -4
- data/vendor/faiss/faiss/Clustering.h +10 -1
- data/vendor/faiss/faiss/IVFlib.cpp +4 -1
- data/vendor/faiss/faiss/Index.h +21 -6
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
- data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
- data/vendor/faiss/faiss/IndexHNSW.h +52 -3
- data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVF.h +9 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
- data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.h +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
- data/vendor/faiss/faiss/impl/HNSW.h +43 -22
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
- data/vendor/faiss/faiss/impl/io.cpp +13 -5
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
- data/vendor/faiss/faiss/index_factory.cpp +31 -13
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +58 -88
- data/vendor/faiss/faiss/utils/distances.h +5 -5
- data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +10 -3
- data/vendor/faiss/faiss/utils/utils.h +3 -0
- metadata +16 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: bdce4ec4f4169dff5f08ccbed2de2750dfd33738fe60d747645f7aaa43187505
|
|
4
|
+
data.tar.gz: a8ab702eead45525bb4aae8b28b9c20bc0d0d8c774a79ef942a9c8d7a9cabc2f
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 7e8291961c8a8550e745c55eef5011ca23fc6f5ce7452eeb6da45ebfd020f7c07df70a0a5d7c281e2449214d5ec26102f9194f1aa49d0b9be21304dad3a98368
|
|
7
|
+
data.tar.gz: 80b475d06b237902b88025dc2602a7e7c8ad15ec757cd43d63d143423eb7a1bd759b8c30715b9ec30c2ae3cfecd2eea502e9814524219d396c71067f0959b62e
|
data/CHANGELOG.md
CHANGED
data/lib/faiss/version.rb
CHANGED
data/vendor/faiss/faiss/AutoTune.h
CHANGED
|
@@ -86,7 +86,7 @@ struct OperatingPoint {
|
|
|
86
86
|
double perf; ///< performance measure (output of a Criterion)
|
|
87
87
|
double t; ///< corresponding execution time (ms)
|
|
88
88
|
std::string key; ///< key that identifies this op pt
|
|
89
|
-
int64_t cno; ///< integer
|
|
89
|
+
int64_t cno; ///< integer identifier
|
|
90
90
|
};
|
|
91
91
|
|
|
92
92
|
struct OperatingPoints {
|
data/vendor/faiss/faiss/Clustering.cpp
CHANGED
|
@@ -11,6 +11,7 @@
|
|
|
11
11
|
#include <faiss/VectorTransform.h>
|
|
12
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
13
13
|
|
|
14
|
+
#include <chrono>
|
|
14
15
|
#include <cinttypes>
|
|
15
16
|
#include <cmath>
|
|
16
17
|
#include <cstdio>
|
|
@@ -74,6 +75,14 @@ void Clustering::train(
|
|
|
74
75
|
|
|
75
76
|
namespace {
|
|
76
77
|
|
|
78
|
+
uint64_t get_actual_rng_seed(const int seed) {
|
|
79
|
+
return (seed >= 0)
|
|
80
|
+
? seed
|
|
81
|
+
: static_cast<uint64_t>(std::chrono::high_resolution_clock::now()
|
|
82
|
+
.time_since_epoch()
|
|
83
|
+
.count());
|
|
84
|
+
}
|
|
85
|
+
|
|
77
86
|
idx_t subsample_training_set(
|
|
78
87
|
const Clustering& clus,
|
|
79
88
|
idx_t nx,
|
|
@@ -87,11 +96,30 @@ idx_t subsample_training_set(
|
|
|
87
96
|
clus.k * clus.max_points_per_centroid,
|
|
88
97
|
nx);
|
|
89
98
|
}
|
|
90
|
-
|
|
91
|
-
|
|
99
|
+
|
|
100
|
+
const uint64_t actual_seed = get_actual_rng_seed(clus.seed);
|
|
101
|
+
|
|
102
|
+
std::vector<int> perm;
|
|
103
|
+
if (clus.use_faster_subsampling) {
|
|
104
|
+
// use subsampling with splitmix64 rng
|
|
105
|
+
SplitMix64RandomGenerator rng(actual_seed);
|
|
106
|
+
|
|
107
|
+
const idx_t new_nx = clus.k * clus.max_points_per_centroid;
|
|
108
|
+
perm.resize(new_nx);
|
|
109
|
+
for (idx_t i = 0; i < new_nx; i++) {
|
|
110
|
+
perm[i] = rng.rand_int(nx);
|
|
111
|
+
}
|
|
112
|
+
} else {
|
|
113
|
+
// use subsampling with a default std rng
|
|
114
|
+
perm.resize(nx);
|
|
115
|
+
rand_perm(perm.data(), nx, actual_seed);
|
|
116
|
+
}
|
|
117
|
+
|
|
92
118
|
nx = clus.k * clus.max_points_per_centroid;
|
|
93
119
|
uint8_t* x_new = new uint8_t[nx * line_size];
|
|
94
120
|
*x_out = x_new;
|
|
121
|
+
|
|
122
|
+
// might be worth omp-ing as well
|
|
95
123
|
for (idx_t i = 0; i < nx; i++) {
|
|
96
124
|
memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
|
|
97
125
|
}
|
|
@@ -280,7 +308,7 @@ void Clustering::train_encoded(
|
|
|
280
308
|
|
|
281
309
|
double t0 = getmillisecs();
|
|
282
310
|
|
|
283
|
-
if (!codec) {
|
|
311
|
+
if (!codec && check_input_data_for_NaNs) {
|
|
284
312
|
// Check for NaNs in input data. Normally it is the user's
|
|
285
313
|
// responsibility, but it may spare us some hard-to-debug
|
|
286
314
|
// reports.
|
|
@@ -383,6 +411,9 @@ void Clustering::train_encoded(
|
|
|
383
411
|
}
|
|
384
412
|
t0 = getmillisecs();
|
|
385
413
|
|
|
414
|
+
// initialize seed
|
|
415
|
+
const uint64_t actual_seed = get_actual_rng_seed(seed);
|
|
416
|
+
|
|
386
417
|
// temporary buffer to decode vectors during the optimization
|
|
387
418
|
std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
|
|
388
419
|
|
|
@@ -395,7 +426,7 @@ void Clustering::train_encoded(
|
|
|
395
426
|
centroids.resize(d * k);
|
|
396
427
|
std::vector<int> perm(nx);
|
|
397
428
|
|
|
398
|
-
rand_perm(perm.data(), nx,
|
|
429
|
+
rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
|
|
399
430
|
|
|
400
431
|
if (!codec) {
|
|
401
432
|
for (int i = n_input_centroids; i < k; i++) {
|
data/vendor/faiss/faiss/Clustering.h
CHANGED
|
@@ -43,11 +43,20 @@ struct ClusteringParameters {
|
|
|
43
43
|
int min_points_per_centroid = 39;
|
|
44
44
|
/// to limit size of dataset, otherwise the training set is subsampled
|
|
45
45
|
int max_points_per_centroid = 256;
|
|
46
|
-
/// seed for the random number generator
|
|
46
|
+
/// seed for the random number generator.
|
|
47
|
+
/// negative values lead to seeding an internal rng with
|
|
48
|
+
/// std::high_resolution_clock.
|
|
47
49
|
int seed = 1234;
|
|
48
50
|
|
|
49
51
|
/// when the training set is encoded, batch size of the codec decoder
|
|
50
52
|
size_t decode_block_size = 32768;
|
|
53
|
+
|
|
54
|
+
/// whether to check for NaNs in an input data
|
|
55
|
+
bool check_input_data_for_NaNs = true;
|
|
56
|
+
|
|
57
|
+
/// Whether to use splitmix64-based random number generator for subsampling,
|
|
58
|
+
/// which is faster, but may pick duplicate points.
|
|
59
|
+
bool use_faster_subsampling = false;
|
|
51
60
|
};
|
|
52
61
|
|
|
53
62
|
struct ClusteringIterationStats {
|
data/vendor/faiss/faiss/IVFlib.cpp
CHANGED
|
@@ -352,7 +352,10 @@ void search_with_parameters(
|
|
|
352
352
|
const IndexIVF* index_ivf = dynamic_cast<const IndexIVF*>(index);
|
|
353
353
|
FAISS_THROW_IF_NOT(index_ivf);
|
|
354
354
|
|
|
355
|
-
|
|
355
|
+
SearchParameters* quantizer_params =
|
|
356
|
+
(params) ? params->quantizer_params : nullptr;
|
|
357
|
+
index_ivf->quantizer->search(
|
|
358
|
+
n, x, params->nprobe, Dq.data(), Iq.data(), quantizer_params);
|
|
356
359
|
|
|
357
360
|
if (nb_dis_ptr) {
|
|
358
361
|
*nb_dis_ptr = count_ndis(index_ivf, n * params->nprobe, Iq.data());
|
data/vendor/faiss/faiss/Index.h
CHANGED
|
@@ -17,9 +17,21 @@
|
|
|
17
17
|
#include <typeinfo>
|
|
18
18
|
|
|
19
19
|
#define FAISS_VERSION_MAJOR 1
|
|
20
|
-
#define FAISS_VERSION_MINOR
|
|
20
|
+
#define FAISS_VERSION_MINOR 9
|
|
21
21
|
#define FAISS_VERSION_PATCH 0
|
|
22
22
|
|
|
23
|
+
// Macro to combine the version components into a single string
|
|
24
|
+
#ifndef FAISS_STRINGIFY
|
|
25
|
+
#define FAISS_STRINGIFY(ARG) #ARG
|
|
26
|
+
#endif
|
|
27
|
+
#ifndef FAISS_TOSTRING
|
|
28
|
+
#define FAISS_TOSTRING(ARG) FAISS_STRINGIFY(ARG)
|
|
29
|
+
#endif
|
|
30
|
+
#define VERSION_STRING \
|
|
31
|
+
FAISS_TOSTRING(FAISS_VERSION_MAJOR) \
|
|
32
|
+
"." FAISS_TOSTRING(FAISS_VERSION_MINOR) "." FAISS_TOSTRING( \
|
|
33
|
+
FAISS_VERSION_PATCH)
|
|
34
|
+
|
|
23
35
|
/**
|
|
24
36
|
* @namespace faiss
|
|
25
37
|
*
|
|
@@ -38,8 +50,8 @@
|
|
|
38
50
|
|
|
39
51
|
namespace faiss {
|
|
40
52
|
|
|
41
|
-
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h
|
|
42
|
-
/// impl/DistanceComputer.h
|
|
53
|
+
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h
|
|
54
|
+
/// and impl/DistanceComputer.h
|
|
43
55
|
struct IDSelector;
|
|
44
56
|
struct RangeSearchResult;
|
|
45
57
|
struct DistanceComputer;
|
|
@@ -56,7 +68,8 @@ struct SearchParameters {
|
|
|
56
68
|
virtual ~SearchParameters() {}
|
|
57
69
|
};
|
|
58
70
|
|
|
59
|
-
/** Abstract structure for an index, supports adding vectors and searching
|
|
71
|
+
/** Abstract structure for an index, supports adding vectors and searching
|
|
72
|
+
* them.
|
|
60
73
|
*
|
|
61
74
|
* All vectors provided at add or search time are 32-bit float arrays,
|
|
62
75
|
* although the internal representation may vary.
|
|
@@ -154,7 +167,8 @@ struct Index {
|
|
|
154
167
|
|
|
155
168
|
/** return the indexes of the k vectors closest to the query x.
|
|
156
169
|
*
|
|
157
|
-
* This function is identical as search but only return labels of
|
|
170
|
+
* This function is identical as search but only return labels of
|
|
171
|
+
* neighbors.
|
|
158
172
|
* @param n number of vectors
|
|
159
173
|
* @param x input vectors to search, size n * d
|
|
160
174
|
* @param labels output labels of the NNs, size n*k
|
|
@@ -179,7 +193,8 @@ struct Index {
|
|
|
179
193
|
*/
|
|
180
194
|
virtual void reconstruct(idx_t key, float* recons) const;
|
|
181
195
|
|
|
182
|
-
/** Reconstruct several stored vectors (or an approximation if lossy
|
|
196
|
+
/** Reconstruct several stored vectors (or an approximation if lossy
|
|
197
|
+
* coding)
|
|
183
198
|
*
|
|
184
199
|
* this function may not be defined for some indexes
|
|
185
200
|
* @param n number of vectors to reconstruct
|
data/vendor/faiss/faiss/IndexFastScan.cpp
CHANGED
|
@@ -189,6 +189,7 @@ void estimators_from_tables_generic(
|
|
|
189
189
|
dt += index.ksub;
|
|
190
190
|
}
|
|
191
191
|
}
|
|
192
|
+
|
|
192
193
|
if (C::cmp(heap_dis[0], dis)) {
|
|
193
194
|
heap_pop<C>(k, heap_dis, heap_ids);
|
|
194
195
|
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
|
@@ -203,17 +204,18 @@ ResultHandlerCompare<C, false>* make_knn_handler(
|
|
|
203
204
|
idx_t k,
|
|
204
205
|
size_t ntotal,
|
|
205
206
|
float* distances,
|
|
206
|
-
idx_t* labels
|
|
207
|
+
idx_t* labels,
|
|
208
|
+
const IDSelector* sel = nullptr) {
|
|
207
209
|
using HeapHC = HeapHandler<C, false>;
|
|
208
210
|
using ReservoirHC = ReservoirHandler<C, false>;
|
|
209
211
|
using SingleResultHC = SingleResultHandler<C, false>;
|
|
210
212
|
|
|
211
213
|
if (k == 1) {
|
|
212
|
-
return new SingleResultHC(n, ntotal, distances, labels);
|
|
214
|
+
return new SingleResultHC(n, ntotal, distances, labels, sel);
|
|
213
215
|
} else if (impl % 2 == 0) {
|
|
214
|
-
return new HeapHC(n, ntotal, k, distances, labels);
|
|
216
|
+
return new HeapHC(n, ntotal, k, distances, labels, sel);
|
|
215
217
|
} else /* if (impl % 2 == 1) */ {
|
|
216
|
-
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels);
|
|
218
|
+
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
|
|
217
219
|
}
|
|
218
220
|
}
|
|
219
221
|
|
|
@@ -547,6 +549,22 @@ void IndexFastScan::search_implem_14(
|
|
|
547
549
|
}
|
|
548
550
|
}
|
|
549
551
|
|
|
552
|
+
template void IndexFastScan::search_dispatch_implem<true>(
|
|
553
|
+
idx_t n,
|
|
554
|
+
const float* x,
|
|
555
|
+
idx_t k,
|
|
556
|
+
float* distances,
|
|
557
|
+
idx_t* labels,
|
|
558
|
+
const NormTableScaler* scaler) const;
|
|
559
|
+
|
|
560
|
+
template void IndexFastScan::search_dispatch_implem<false>(
|
|
561
|
+
idx_t n,
|
|
562
|
+
const float* x,
|
|
563
|
+
idx_t k,
|
|
564
|
+
float* distances,
|
|
565
|
+
idx_t* labels,
|
|
566
|
+
const NormTableScaler* scaler) const;
|
|
567
|
+
|
|
550
568
|
void IndexFastScan::reconstruct(idx_t key, float* recons) const {
|
|
551
569
|
std::vector<uint8_t> code(code_size, 0);
|
|
552
570
|
BitstringWriter bsw(code.data(), code_size);
|
data/vendor/faiss/faiss/IndexFlat.cpp
CHANGED
|
@@ -41,15 +41,19 @@ void IndexFlat::search(
|
|
|
41
41
|
} else if (metric_type == METRIC_L2) {
|
|
42
42
|
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
|
43
43
|
knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
|
|
44
|
-
} else if (is_similarity_metric(metric_type)) {
|
|
45
|
-
float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
|
46
|
-
knn_extra_metrics(
|
|
47
|
-
x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
|
|
48
44
|
} else {
|
|
49
|
-
FAISS_THROW_IF_NOT(!sel);
|
|
50
|
-
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
|
45
|
+
FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
|
|
51
46
|
knn_extra_metrics(
|
|
52
|
-
x,
|
|
47
|
+
x,
|
|
48
|
+
get_xb(),
|
|
49
|
+
d,
|
|
50
|
+
n,
|
|
51
|
+
ntotal,
|
|
52
|
+
metric_type,
|
|
53
|
+
metric_arg,
|
|
54
|
+
k,
|
|
55
|
+
distances,
|
|
56
|
+
labels);
|
|
53
57
|
}
|
|
54
58
|
}
|
|
55
59
|
|
data/vendor/faiss/faiss/IndexFlatCodes.cpp
CHANGED
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
#include <faiss/impl/DistanceComputer.h>
|
|
13
13
|
#include <faiss/impl/FaissAssert.h>
|
|
14
14
|
#include <faiss/impl/IDSelector.h>
|
|
15
|
+
#include <faiss/impl/ResultHandler.h>
|
|
16
|
+
#include <faiss/utils/extra_distances.h>
|
|
15
17
|
|
|
16
18
|
namespace faiss {
|
|
17
19
|
|
|
@@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
|
|
|
70
72
|
reconstruct_n(key, 1, recons);
|
|
71
73
|
}
|
|
72
74
|
|
|
73
|
-
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
|
|
74
|
-
const {
|
|
75
|
-
FAISS_THROW_MSG("not implemented");
|
|
76
|
-
}
|
|
77
|
-
|
|
78
75
|
void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
|
|
79
76
|
// minimal sanity checks
|
|
80
77
|
const IndexFlatCodes* other =
|
|
@@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
|
|
|
114
111
|
std::swap(codes, new_codes);
|
|
115
112
|
}
|
|
116
113
|
|
|
114
|
+
namespace {
|
|
115
|
+
|
|
116
|
+
template <class VD>
|
|
117
|
+
struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
|
|
118
|
+
const IndexFlatCodes& codec;
|
|
119
|
+
const VD vd;
|
|
120
|
+
// temp buffers
|
|
121
|
+
std::vector<uint8_t> code_buffer;
|
|
122
|
+
std::vector<float> vec_buffer;
|
|
123
|
+
const float* query = nullptr;
|
|
124
|
+
|
|
125
|
+
GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
|
|
126
|
+
: FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
|
|
127
|
+
codec(*codec),
|
|
128
|
+
vd(vd),
|
|
129
|
+
code_buffer(codec->code_size * 4),
|
|
130
|
+
vec_buffer(codec->d * 4) {}
|
|
131
|
+
|
|
132
|
+
void set_query(const float* x) override {
|
|
133
|
+
query = x;
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
float operator()(idx_t i) override {
|
|
137
|
+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
|
|
138
|
+
return vd(query, vec_buffer.data());
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
float distance_to_code(const uint8_t* code) override {
|
|
142
|
+
codec.sa_decode(1, code, vec_buffer.data());
|
|
143
|
+
return vd(query, vec_buffer.data());
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
147
|
+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
|
|
148
|
+
codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
|
|
149
|
+
return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
void distances_batch_4(
|
|
153
|
+
const idx_t idx0,
|
|
154
|
+
const idx_t idx1,
|
|
155
|
+
const idx_t idx2,
|
|
156
|
+
const idx_t idx3,
|
|
157
|
+
float& dis0,
|
|
158
|
+
float& dis1,
|
|
159
|
+
float& dis2,
|
|
160
|
+
float& dis3) override {
|
|
161
|
+
uint8_t* cp = code_buffer.data();
|
|
162
|
+
for (idx_t i : {idx0, idx1, idx2, idx3}) {
|
|
163
|
+
memcpy(cp, codes + i * code_size, code_size);
|
|
164
|
+
cp += code_size;
|
|
165
|
+
}
|
|
166
|
+
// potential benefit is if batch decoding is more efficient than 1 by 1
|
|
167
|
+
// decoding
|
|
168
|
+
codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
|
|
169
|
+
dis0 = vd(query, vec_buffer.data());
|
|
170
|
+
dis1 = vd(query, vec_buffer.data() + vd.d);
|
|
171
|
+
dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
|
|
172
|
+
dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
|
|
173
|
+
}
|
|
174
|
+
};
|
|
175
|
+
|
|
176
|
+
struct Run_get_distance_computer {
|
|
177
|
+
using T = FlatCodesDistanceComputer*;
|
|
178
|
+
|
|
179
|
+
template <class VD>
|
|
180
|
+
FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
|
|
181
|
+
return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
|
|
182
|
+
}
|
|
183
|
+
};
|
|
184
|
+
|
|
185
|
+
template <class BlockResultHandler>
|
|
186
|
+
struct Run_search_with_decompress {
|
|
187
|
+
using T = void;
|
|
188
|
+
|
|
189
|
+
template <class VectorDistance>
|
|
190
|
+
void f(VectorDistance& vd,
|
|
191
|
+
const IndexFlatCodes* index_ptr,
|
|
192
|
+
const float* xq,
|
|
193
|
+
BlockResultHandler& res) {
|
|
194
|
+
// Note that there seems to be a clang (?) bug that "sometimes" passes
|
|
195
|
+
// the const Index & parameters by value, so to be on the safe side,
|
|
196
|
+
// it's better to use pointers.
|
|
197
|
+
const IndexFlatCodes& index = *index_ptr;
|
|
198
|
+
size_t ntotal = index.ntotal;
|
|
199
|
+
using SingleResultHandler =
|
|
200
|
+
typename BlockResultHandler::SingleResultHandler;
|
|
201
|
+
using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
|
|
202
|
+
#pragma omp parallel // if (res.nq > 100)
|
|
203
|
+
{
|
|
204
|
+
std::unique_ptr<DC> dc(new DC(&index, vd));
|
|
205
|
+
SingleResultHandler resi(res);
|
|
206
|
+
#pragma omp for
|
|
207
|
+
for (int64_t q = 0; q < res.nq; q++) {
|
|
208
|
+
resi.begin(q);
|
|
209
|
+
dc->set_query(xq + vd.d * q);
|
|
210
|
+
for (size_t i = 0; i < ntotal; i++) {
|
|
211
|
+
if (res.is_in_selection(i)) {
|
|
212
|
+
float dis = (*dc)(i);
|
|
213
|
+
resi.add_result(dis, i);
|
|
214
|
+
}
|
|
215
|
+
}
|
|
216
|
+
resi.end();
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
};
|
|
221
|
+
|
|
222
|
+
struct Run_search_with_decompress_res {
|
|
223
|
+
using T = void;
|
|
224
|
+
|
|
225
|
+
template <class ResultHandler>
|
|
226
|
+
void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
|
|
227
|
+
Run_search_with_decompress<ResultHandler> r;
|
|
228
|
+
dispatch_VectorDistance(
|
|
229
|
+
index->d,
|
|
230
|
+
index->metric_type,
|
|
231
|
+
index->metric_arg,
|
|
232
|
+
r,
|
|
233
|
+
index,
|
|
234
|
+
xq,
|
|
235
|
+
res);
|
|
236
|
+
}
|
|
237
|
+
};
|
|
238
|
+
|
|
239
|
+
} // anonymous namespace
|
|
240
|
+
|
|
241
|
+
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
|
|
242
|
+
const {
|
|
243
|
+
Run_get_distance_computer r;
|
|
244
|
+
return dispatch_VectorDistance(d, metric_type, metric_arg, r, this);
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
void IndexFlatCodes::search(
|
|
248
|
+
idx_t n,
|
|
249
|
+
const float* x,
|
|
250
|
+
idx_t k,
|
|
251
|
+
float* distances,
|
|
252
|
+
idx_t* labels,
|
|
253
|
+
const SearchParameters* params) const {
|
|
254
|
+
Run_search_with_decompress_res r;
|
|
255
|
+
const IDSelector* sel = params ? params->sel : nullptr;
|
|
256
|
+
dispatch_knn_ResultHandler(
|
|
257
|
+
n, distances, labels, k, metric_type, sel, r, this, x);
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
void IndexFlatCodes::range_search(
|
|
261
|
+
idx_t n,
|
|
262
|
+
const float* x,
|
|
263
|
+
float radius,
|
|
264
|
+
RangeSearchResult* result,
|
|
265
|
+
const SearchParameters* params) const {
|
|
266
|
+
const IDSelector* sel = params ? params->sel : nullptr;
|
|
267
|
+
Run_search_with_decompress_res r;
|
|
268
|
+
dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
|
|
269
|
+
}
|
|
270
|
+
|
|
117
271
|
} // namespace faiss
|
data/vendor/faiss/faiss/IndexFlatCodes.h
CHANGED
|
@@ -5,8 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#pragma once
|
|
11
9
|
|
|
12
10
|
#include <faiss/Index.h>
|
|
@@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
|
|
|
45
43
|
* different from the usual ones: the new ids are shifted */
|
|
46
44
|
size_t remove_ids(const IDSelector& sel) override;
|
|
47
45
|
|
|
48
|
-
/** a FlatCodesDistanceComputer offers a distance_to_code method
|
|
46
|
+
/** a FlatCodesDistanceComputer offers a distance_to_code method
|
|
47
|
+
*
|
|
48
|
+
* The default implementation explicitly decodes the vector with sa_decode.
|
|
49
|
+
*/
|
|
49
50
|
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
|
|
50
51
|
|
|
51
52
|
DistanceComputer* get_distance_computer() const override {
|
|
52
53
|
return get_FlatCodesDistanceComputer();
|
|
53
54
|
}
|
|
54
55
|
|
|
56
|
+
/** Search implemented by decoding */
|
|
57
|
+
void search(
|
|
58
|
+
idx_t n,
|
|
59
|
+
const float* x,
|
|
60
|
+
idx_t k,
|
|
61
|
+
float* distances,
|
|
62
|
+
idx_t* labels,
|
|
63
|
+
const SearchParameters* params = nullptr) const override;
|
|
64
|
+
|
|
65
|
+
void range_search(
|
|
66
|
+
idx_t n,
|
|
67
|
+
const float* x,
|
|
68
|
+
float radius,
|
|
69
|
+
RangeSearchResult* result,
|
|
70
|
+
const SearchParameters* params = nullptr) const override;
|
|
71
|
+
|
|
55
72
|
// returns a new instance of a CodePacker
|
|
56
73
|
CodePacker* get_CodePacker() const;
|
|
57
74
|
|