faiss 0.3.1 → 0.3.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +35 -4
- data/vendor/faiss/faiss/Clustering.h +10 -1
- data/vendor/faiss/faiss/IVFlib.cpp +4 -1
- data/vendor/faiss/faiss/Index.h +21 -6
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
- data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
- data/vendor/faiss/faiss/IndexHNSW.h +52 -3
- data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVF.h +9 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
- data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.h +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
- data/vendor/faiss/faiss/impl/HNSW.h +43 -22
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
- data/vendor/faiss/faiss/impl/io.cpp +13 -5
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
- data/vendor/faiss/faiss/index_factory.cpp +31 -13
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +58 -88
- data/vendor/faiss/faiss/utils/distances.h +5 -5
- data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +10 -3
- data/vendor/faiss/faiss/utils/utils.h +3 -0
- metadata +16 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: bdce4ec4f4169dff5f08ccbed2de2750dfd33738fe60d747645f7aaa43187505
|
4
|
+
data.tar.gz: a8ab702eead45525bb4aae8b28b9c20bc0d0d8c774a79ef942a9c8d7a9cabc2f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7e8291961c8a8550e745c55eef5011ca23fc6f5ce7452eeb6da45ebfd020f7c07df70a0a5d7c281e2449214d5ec26102f9194f1aa49d0b9be21304dad3a98368
|
7
|
+
data.tar.gz: 80b475d06b237902b88025dc2602a7e7c8ad15ec757cd43d63d143423eb7a1bd759b8c30715b9ec30c2ae3cfecd2eea502e9814524219d396c71067f0959b62e
|
data/CHANGELOG.md
CHANGED
data/lib/faiss/version.rb
CHANGED
@@ -86,7 +86,7 @@ struct OperatingPoint {
|
|
86
86
|
double perf; ///< performance measure (output of a Criterion)
|
87
87
|
double t; ///< corresponding execution time (ms)
|
88
88
|
std::string key; ///< key that identifies this op pt
|
89
|
-
int64_t cno; ///< integer
|
89
|
+
int64_t cno; ///< integer identifier
|
90
90
|
};
|
91
91
|
|
92
92
|
struct OperatingPoints {
|
@@ -11,6 +11,7 @@
|
|
11
11
|
#include <faiss/VectorTransform.h>
|
12
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
13
13
|
|
14
|
+
#include <chrono>
|
14
15
|
#include <cinttypes>
|
15
16
|
#include <cmath>
|
16
17
|
#include <cstdio>
|
@@ -74,6 +75,14 @@ void Clustering::train(
|
|
74
75
|
|
75
76
|
namespace {
|
76
77
|
|
78
|
+
uint64_t get_actual_rng_seed(const int seed) {
|
79
|
+
return (seed >= 0)
|
80
|
+
? seed
|
81
|
+
: static_cast<uint64_t>(std::chrono::high_resolution_clock::now()
|
82
|
+
.time_since_epoch()
|
83
|
+
.count());
|
84
|
+
}
|
85
|
+
|
77
86
|
idx_t subsample_training_set(
|
78
87
|
const Clustering& clus,
|
79
88
|
idx_t nx,
|
@@ -87,11 +96,30 @@ idx_t subsample_training_set(
|
|
87
96
|
clus.k * clus.max_points_per_centroid,
|
88
97
|
nx);
|
89
98
|
}
|
90
|
-
|
91
|
-
|
99
|
+
|
100
|
+
const uint64_t actual_seed = get_actual_rng_seed(clus.seed);
|
101
|
+
|
102
|
+
std::vector<int> perm;
|
103
|
+
if (clus.use_faster_subsampling) {
|
104
|
+
// use subsampling with splitmix64 rng
|
105
|
+
SplitMix64RandomGenerator rng(actual_seed);
|
106
|
+
|
107
|
+
const idx_t new_nx = clus.k * clus.max_points_per_centroid;
|
108
|
+
perm.resize(new_nx);
|
109
|
+
for (idx_t i = 0; i < new_nx; i++) {
|
110
|
+
perm[i] = rng.rand_int(nx);
|
111
|
+
}
|
112
|
+
} else {
|
113
|
+
// use subsampling with a default std rng
|
114
|
+
perm.resize(nx);
|
115
|
+
rand_perm(perm.data(), nx, actual_seed);
|
116
|
+
}
|
117
|
+
|
92
118
|
nx = clus.k * clus.max_points_per_centroid;
|
93
119
|
uint8_t* x_new = new uint8_t[nx * line_size];
|
94
120
|
*x_out = x_new;
|
121
|
+
|
122
|
+
// might be worth omp-ing as well
|
95
123
|
for (idx_t i = 0; i < nx; i++) {
|
96
124
|
memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
|
97
125
|
}
|
@@ -280,7 +308,7 @@ void Clustering::train_encoded(
|
|
280
308
|
|
281
309
|
double t0 = getmillisecs();
|
282
310
|
|
283
|
-
if (!codec) {
|
311
|
+
if (!codec && check_input_data_for_NaNs) {
|
284
312
|
// Check for NaNs in input data. Normally it is the user's
|
285
313
|
// responsibility, but it may spare us some hard-to-debug
|
286
314
|
// reports.
|
@@ -383,6 +411,9 @@ void Clustering::train_encoded(
|
|
383
411
|
}
|
384
412
|
t0 = getmillisecs();
|
385
413
|
|
414
|
+
// initialize seed
|
415
|
+
const uint64_t actual_seed = get_actual_rng_seed(seed);
|
416
|
+
|
386
417
|
// temporary buffer to decode vectors during the optimization
|
387
418
|
std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
|
388
419
|
|
@@ -395,7 +426,7 @@ void Clustering::train_encoded(
|
|
395
426
|
centroids.resize(d * k);
|
396
427
|
std::vector<int> perm(nx);
|
397
428
|
|
398
|
-
rand_perm(perm.data(), nx,
|
429
|
+
rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
|
399
430
|
|
400
431
|
if (!codec) {
|
401
432
|
for (int i = n_input_centroids; i < k; i++) {
|
@@ -43,11 +43,20 @@ struct ClusteringParameters {
|
|
43
43
|
int min_points_per_centroid = 39;
|
44
44
|
/// to limit size of dataset, otherwise the training set is subsampled
|
45
45
|
int max_points_per_centroid = 256;
|
46
|
-
/// seed for the random number generator
|
46
|
+
/// seed for the random number generator.
|
47
|
+
/// negative values lead to seeding an internal rng with
|
48
|
+
/// std::high_resolution_clock.
|
47
49
|
int seed = 1234;
|
48
50
|
|
49
51
|
/// when the training set is encoded, batch size of the codec decoder
|
50
52
|
size_t decode_block_size = 32768;
|
53
|
+
|
54
|
+
/// whether to check for NaNs in an input data
|
55
|
+
bool check_input_data_for_NaNs = true;
|
56
|
+
|
57
|
+
/// Whether to use splitmix64-based random number generator for subsampling,
|
58
|
+
/// which is faster, but may pick duplicate points.
|
59
|
+
bool use_faster_subsampling = false;
|
51
60
|
};
|
52
61
|
|
53
62
|
struct ClusteringIterationStats {
|
@@ -352,7 +352,10 @@ void search_with_parameters(
|
|
352
352
|
const IndexIVF* index_ivf = dynamic_cast<const IndexIVF*>(index);
|
353
353
|
FAISS_THROW_IF_NOT(index_ivf);
|
354
354
|
|
355
|
-
|
355
|
+
SearchParameters* quantizer_params =
|
356
|
+
(params) ? params->quantizer_params : nullptr;
|
357
|
+
index_ivf->quantizer->search(
|
358
|
+
n, x, params->nprobe, Dq.data(), Iq.data(), quantizer_params);
|
356
359
|
|
357
360
|
if (nb_dis_ptr) {
|
358
361
|
*nb_dis_ptr = count_ndis(index_ivf, n * params->nprobe, Iq.data());
|
data/vendor/faiss/faiss/Index.h
CHANGED
@@ -17,9 +17,21 @@
|
|
17
17
|
#include <typeinfo>
|
18
18
|
|
19
19
|
#define FAISS_VERSION_MAJOR 1
|
20
|
-
#define FAISS_VERSION_MINOR
|
20
|
+
#define FAISS_VERSION_MINOR 9
|
21
21
|
#define FAISS_VERSION_PATCH 0
|
22
22
|
|
23
|
+
// Macro to combine the version components into a single string
|
24
|
+
#ifndef FAISS_STRINGIFY
|
25
|
+
#define FAISS_STRINGIFY(ARG) #ARG
|
26
|
+
#endif
|
27
|
+
#ifndef FAISS_TOSTRING
|
28
|
+
#define FAISS_TOSTRING(ARG) FAISS_STRINGIFY(ARG)
|
29
|
+
#endif
|
30
|
+
#define VERSION_STRING \
|
31
|
+
FAISS_TOSTRING(FAISS_VERSION_MAJOR) \
|
32
|
+
"." FAISS_TOSTRING(FAISS_VERSION_MINOR) "." FAISS_TOSTRING( \
|
33
|
+
FAISS_VERSION_PATCH)
|
34
|
+
|
23
35
|
/**
|
24
36
|
* @namespace faiss
|
25
37
|
*
|
@@ -38,8 +50,8 @@
|
|
38
50
|
|
39
51
|
namespace faiss {
|
40
52
|
|
41
|
-
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h
|
42
|
-
/// impl/DistanceComputer.h
|
53
|
+
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h
|
54
|
+
/// and impl/DistanceComputer.h
|
43
55
|
struct IDSelector;
|
44
56
|
struct RangeSearchResult;
|
45
57
|
struct DistanceComputer;
|
@@ -56,7 +68,8 @@ struct SearchParameters {
|
|
56
68
|
virtual ~SearchParameters() {}
|
57
69
|
};
|
58
70
|
|
59
|
-
/** Abstract structure for an index, supports adding vectors and searching
|
71
|
+
/** Abstract structure for an index, supports adding vectors and searching
|
72
|
+
* them.
|
60
73
|
*
|
61
74
|
* All vectors provided at add or search time are 32-bit float arrays,
|
62
75
|
* although the internal representation may vary.
|
@@ -154,7 +167,8 @@ struct Index {
|
|
154
167
|
|
155
168
|
/** return the indexes of the k vectors closest to the query x.
|
156
169
|
*
|
157
|
-
* This function is identical as search but only return labels of
|
170
|
+
* This function is identical as search but only return labels of
|
171
|
+
* neighbors.
|
158
172
|
* @param n number of vectors
|
159
173
|
* @param x input vectors to search, size n * d
|
160
174
|
* @param labels output labels of the NNs, size n*k
|
@@ -179,7 +193,8 @@ struct Index {
|
|
179
193
|
*/
|
180
194
|
virtual void reconstruct(idx_t key, float* recons) const;
|
181
195
|
|
182
|
-
/** Reconstruct several stored vectors (or an approximation if lossy
|
196
|
+
/** Reconstruct several stored vectors (or an approximation if lossy
|
197
|
+
* coding)
|
183
198
|
*
|
184
199
|
* this function may not be defined for some indexes
|
185
200
|
* @param n number of vectors to reconstruct
|
@@ -189,6 +189,7 @@ void estimators_from_tables_generic(
|
|
189
189
|
dt += index.ksub;
|
190
190
|
}
|
191
191
|
}
|
192
|
+
|
192
193
|
if (C::cmp(heap_dis[0], dis)) {
|
193
194
|
heap_pop<C>(k, heap_dis, heap_ids);
|
194
195
|
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
@@ -203,17 +204,18 @@ ResultHandlerCompare<C, false>* make_knn_handler(
|
|
203
204
|
idx_t k,
|
204
205
|
size_t ntotal,
|
205
206
|
float* distances,
|
206
|
-
idx_t* labels
|
207
|
+
idx_t* labels,
|
208
|
+
const IDSelector* sel = nullptr) {
|
207
209
|
using HeapHC = HeapHandler<C, false>;
|
208
210
|
using ReservoirHC = ReservoirHandler<C, false>;
|
209
211
|
using SingleResultHC = SingleResultHandler<C, false>;
|
210
212
|
|
211
213
|
if (k == 1) {
|
212
|
-
return new SingleResultHC(n, ntotal, distances, labels);
|
214
|
+
return new SingleResultHC(n, ntotal, distances, labels, sel);
|
213
215
|
} else if (impl % 2 == 0) {
|
214
|
-
return new HeapHC(n, ntotal, k, distances, labels);
|
216
|
+
return new HeapHC(n, ntotal, k, distances, labels, sel);
|
215
217
|
} else /* if (impl % 2 == 1) */ {
|
216
|
-
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels);
|
218
|
+
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
|
217
219
|
}
|
218
220
|
}
|
219
221
|
|
@@ -547,6 +549,22 @@ void IndexFastScan::search_implem_14(
|
|
547
549
|
}
|
548
550
|
}
|
549
551
|
|
552
|
+
template void IndexFastScan::search_dispatch_implem<true>(
|
553
|
+
idx_t n,
|
554
|
+
const float* x,
|
555
|
+
idx_t k,
|
556
|
+
float* distances,
|
557
|
+
idx_t* labels,
|
558
|
+
const NormTableScaler* scaler) const;
|
559
|
+
|
560
|
+
template void IndexFastScan::search_dispatch_implem<false>(
|
561
|
+
idx_t n,
|
562
|
+
const float* x,
|
563
|
+
idx_t k,
|
564
|
+
float* distances,
|
565
|
+
idx_t* labels,
|
566
|
+
const NormTableScaler* scaler) const;
|
567
|
+
|
550
568
|
void IndexFastScan::reconstruct(idx_t key, float* recons) const {
|
551
569
|
std::vector<uint8_t> code(code_size, 0);
|
552
570
|
BitstringWriter bsw(code.data(), code_size);
|
@@ -41,15 +41,19 @@ void IndexFlat::search(
|
|
41
41
|
} else if (metric_type == METRIC_L2) {
|
42
42
|
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
43
43
|
knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
|
44
|
-
} else if (is_similarity_metric(metric_type)) {
|
45
|
-
float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
46
|
-
knn_extra_metrics(
|
47
|
-
x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
|
48
44
|
} else {
|
49
|
-
FAISS_THROW_IF_NOT(!sel);
|
50
|
-
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
45
|
+
FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
|
51
46
|
knn_extra_metrics(
|
52
|
-
x,
|
47
|
+
x,
|
48
|
+
get_xb(),
|
49
|
+
d,
|
50
|
+
n,
|
51
|
+
ntotal,
|
52
|
+
metric_type,
|
53
|
+
metric_arg,
|
54
|
+
k,
|
55
|
+
distances,
|
56
|
+
labels);
|
53
57
|
}
|
54
58
|
}
|
55
59
|
|
@@ -12,6 +12,8 @@
|
|
12
12
|
#include <faiss/impl/DistanceComputer.h>
|
13
13
|
#include <faiss/impl/FaissAssert.h>
|
14
14
|
#include <faiss/impl/IDSelector.h>
|
15
|
+
#include <faiss/impl/ResultHandler.h>
|
16
|
+
#include <faiss/utils/extra_distances.h>
|
15
17
|
|
16
18
|
namespace faiss {
|
17
19
|
|
@@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
|
|
70
72
|
reconstruct_n(key, 1, recons);
|
71
73
|
}
|
72
74
|
|
73
|
-
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
|
74
|
-
const {
|
75
|
-
FAISS_THROW_MSG("not implemented");
|
76
|
-
}
|
77
|
-
|
78
75
|
void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
|
79
76
|
// minimal sanity checks
|
80
77
|
const IndexFlatCodes* other =
|
@@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
|
|
114
111
|
std::swap(codes, new_codes);
|
115
112
|
}
|
116
113
|
|
114
|
+
namespace {
|
115
|
+
|
116
|
+
template <class VD>
|
117
|
+
struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
|
118
|
+
const IndexFlatCodes& codec;
|
119
|
+
const VD vd;
|
120
|
+
// temp buffers
|
121
|
+
std::vector<uint8_t> code_buffer;
|
122
|
+
std::vector<float> vec_buffer;
|
123
|
+
const float* query = nullptr;
|
124
|
+
|
125
|
+
GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
|
126
|
+
: FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
|
127
|
+
codec(*codec),
|
128
|
+
vd(vd),
|
129
|
+
code_buffer(codec->code_size * 4),
|
130
|
+
vec_buffer(codec->d * 4) {}
|
131
|
+
|
132
|
+
void set_query(const float* x) override {
|
133
|
+
query = x;
|
134
|
+
}
|
135
|
+
|
136
|
+
float operator()(idx_t i) override {
|
137
|
+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
|
138
|
+
return vd(query, vec_buffer.data());
|
139
|
+
}
|
140
|
+
|
141
|
+
float distance_to_code(const uint8_t* code) override {
|
142
|
+
codec.sa_decode(1, code, vec_buffer.data());
|
143
|
+
return vd(query, vec_buffer.data());
|
144
|
+
}
|
145
|
+
|
146
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
147
|
+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
|
148
|
+
codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
|
149
|
+
return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
|
150
|
+
}
|
151
|
+
|
152
|
+
void distances_batch_4(
|
153
|
+
const idx_t idx0,
|
154
|
+
const idx_t idx1,
|
155
|
+
const idx_t idx2,
|
156
|
+
const idx_t idx3,
|
157
|
+
float& dis0,
|
158
|
+
float& dis1,
|
159
|
+
float& dis2,
|
160
|
+
float& dis3) override {
|
161
|
+
uint8_t* cp = code_buffer.data();
|
162
|
+
for (idx_t i : {idx0, idx1, idx2, idx3}) {
|
163
|
+
memcpy(cp, codes + i * code_size, code_size);
|
164
|
+
cp += code_size;
|
165
|
+
}
|
166
|
+
// potential benefit is if batch decoding is more efficient than 1 by 1
|
167
|
+
// decoding
|
168
|
+
codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
|
169
|
+
dis0 = vd(query, vec_buffer.data());
|
170
|
+
dis1 = vd(query, vec_buffer.data() + vd.d);
|
171
|
+
dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
|
172
|
+
dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
|
173
|
+
}
|
174
|
+
};
|
175
|
+
|
176
|
+
struct Run_get_distance_computer {
|
177
|
+
using T = FlatCodesDistanceComputer*;
|
178
|
+
|
179
|
+
template <class VD>
|
180
|
+
FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
|
181
|
+
return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
|
182
|
+
}
|
183
|
+
};
|
184
|
+
|
185
|
+
template <class BlockResultHandler>
|
186
|
+
struct Run_search_with_decompress {
|
187
|
+
using T = void;
|
188
|
+
|
189
|
+
template <class VectorDistance>
|
190
|
+
void f(VectorDistance& vd,
|
191
|
+
const IndexFlatCodes* index_ptr,
|
192
|
+
const float* xq,
|
193
|
+
BlockResultHandler& res) {
|
194
|
+
// Note that there seems to be a clang (?) bug that "sometimes" passes
|
195
|
+
// the const Index & parameters by value, so to be on the safe side,
|
196
|
+
// it's better to use pointers.
|
197
|
+
const IndexFlatCodes& index = *index_ptr;
|
198
|
+
size_t ntotal = index.ntotal;
|
199
|
+
using SingleResultHandler =
|
200
|
+
typename BlockResultHandler::SingleResultHandler;
|
201
|
+
using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
|
202
|
+
#pragma omp parallel // if (res.nq > 100)
|
203
|
+
{
|
204
|
+
std::unique_ptr<DC> dc(new DC(&index, vd));
|
205
|
+
SingleResultHandler resi(res);
|
206
|
+
#pragma omp for
|
207
|
+
for (int64_t q = 0; q < res.nq; q++) {
|
208
|
+
resi.begin(q);
|
209
|
+
dc->set_query(xq + vd.d * q);
|
210
|
+
for (size_t i = 0; i < ntotal; i++) {
|
211
|
+
if (res.is_in_selection(i)) {
|
212
|
+
float dis = (*dc)(i);
|
213
|
+
resi.add_result(dis, i);
|
214
|
+
}
|
215
|
+
}
|
216
|
+
resi.end();
|
217
|
+
}
|
218
|
+
}
|
219
|
+
}
|
220
|
+
};
|
221
|
+
|
222
|
+
struct Run_search_with_decompress_res {
|
223
|
+
using T = void;
|
224
|
+
|
225
|
+
template <class ResultHandler>
|
226
|
+
void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
|
227
|
+
Run_search_with_decompress<ResultHandler> r;
|
228
|
+
dispatch_VectorDistance(
|
229
|
+
index->d,
|
230
|
+
index->metric_type,
|
231
|
+
index->metric_arg,
|
232
|
+
r,
|
233
|
+
index,
|
234
|
+
xq,
|
235
|
+
res);
|
236
|
+
}
|
237
|
+
};
|
238
|
+
|
239
|
+
} // anonymous namespace
|
240
|
+
|
241
|
+
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
|
242
|
+
const {
|
243
|
+
Run_get_distance_computer r;
|
244
|
+
return dispatch_VectorDistance(d, metric_type, metric_arg, r, this);
|
245
|
+
}
|
246
|
+
|
247
|
+
void IndexFlatCodes::search(
|
248
|
+
idx_t n,
|
249
|
+
const float* x,
|
250
|
+
idx_t k,
|
251
|
+
float* distances,
|
252
|
+
idx_t* labels,
|
253
|
+
const SearchParameters* params) const {
|
254
|
+
Run_search_with_decompress_res r;
|
255
|
+
const IDSelector* sel = params ? params->sel : nullptr;
|
256
|
+
dispatch_knn_ResultHandler(
|
257
|
+
n, distances, labels, k, metric_type, sel, r, this, x);
|
258
|
+
}
|
259
|
+
|
260
|
+
void IndexFlatCodes::range_search(
|
261
|
+
idx_t n,
|
262
|
+
const float* x,
|
263
|
+
float radius,
|
264
|
+
RangeSearchResult* result,
|
265
|
+
const SearchParameters* params) const {
|
266
|
+
const IDSelector* sel = params ? params->sel : nullptr;
|
267
|
+
Run_search_with_decompress_res r;
|
268
|
+
dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
|
269
|
+
}
|
270
|
+
|
117
271
|
} // namespace faiss
|
@@ -5,8 +5,6 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
#pragma once
|
11
9
|
|
12
10
|
#include <faiss/Index.h>
|
@@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
|
|
45
43
|
* different from the usual ones: the new ids are shifted */
|
46
44
|
size_t remove_ids(const IDSelector& sel) override;
|
47
45
|
|
48
|
-
/** a FlatCodesDistanceComputer offers a distance_to_code method
|
46
|
+
/** a FlatCodesDistanceComputer offers a distance_to_code method
|
47
|
+
*
|
48
|
+
* The default implementation explicitly decodes the vector with sa_decode.
|
49
|
+
*/
|
49
50
|
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
|
50
51
|
|
51
52
|
DistanceComputer* get_distance_computer() const override {
|
52
53
|
return get_FlatCodesDistanceComputer();
|
53
54
|
}
|
54
55
|
|
56
|
+
/** Search implemented by decoding */
|
57
|
+
void search(
|
58
|
+
idx_t n,
|
59
|
+
const float* x,
|
60
|
+
idx_t k,
|
61
|
+
float* distances,
|
62
|
+
idx_t* labels,
|
63
|
+
const SearchParameters* params = nullptr) const override;
|
64
|
+
|
65
|
+
void range_search(
|
66
|
+
idx_t n,
|
67
|
+
const float* x,
|
68
|
+
float radius,
|
69
|
+
RangeSearchResult* result,
|
70
|
+
const SearchParameters* params = nullptr) const override;
|
71
|
+
|
55
72
|
// returns a new instance of a CodePacker
|
56
73
|
CodePacker* get_CodePacker() const;
|
57
74
|
|