faiss 0.3.4 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +11 -8
- data/vendor/faiss/faiss/Clustering.cpp +0 -16
- data/vendor/faiss/faiss/IVFlib.cpp +213 -0
- data/vendor/faiss/faiss/IVFlib.h +42 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -7
- data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +4 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +13 -20
- data/vendor/faiss/faiss/IndexHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +20 -3
- data/vendor/faiss/faiss/IndexIVF.h +5 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +2 -1
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +277 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +70 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +148 -0
- data/vendor/faiss/faiss/IndexRaBitQ.h +65 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -1
- data/vendor/faiss/faiss/clone_index.cpp +38 -3
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +19 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +4 -11
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +13 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +112 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +35 -13
- data/vendor/faiss/faiss/impl/HNSW.h +5 -4
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +519 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +78 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +3 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +220 -25
- data/vendor/faiss/faiss/impl/index_write.cpp +29 -0
- data/vendor/faiss/faiss/impl/io.h +2 -2
- data/vendor/faiss/faiss/impl/io_macros.h +2 -0
- data/vendor/faiss/faiss/impl/mapped_io.cpp +313 -0
- data/vendor/faiss/faiss/impl/mapped_io.h +51 -0
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +316 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +7 -3
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +1 -1
- data/vendor/faiss/faiss/impl/zerocopy_io.cpp +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +32 -0
- data/vendor/faiss/faiss/index_factory.cpp +16 -5
- data/vendor/faiss/faiss/index_io.h +4 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.h +5 -3
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +24 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +22 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +30 -12
- data/vendor/faiss/faiss/utils/hamming.cpp +45 -21
- data/vendor/faiss/faiss/utils/hamming.h +7 -3
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +16 -4
@@ -11,7 +11,7 @@
|
|
11
11
|
#include <cstdint>
|
12
12
|
#include <cstdio>
|
13
13
|
|
14
|
-
#ifdef
|
14
|
+
#ifdef _WIN32
|
15
15
|
|
16
16
|
/*******************************************************
|
17
17
|
* Windows specific macros
|
@@ -23,11 +23,11 @@
|
|
23
23
|
#define FAISS_API __declspec(dllimport)
|
24
24
|
#endif // FAISS_MAIN_LIB
|
25
25
|
|
26
|
-
#ifdef _MSC_VER
|
27
26
|
#define strtok_r strtok_s
|
28
|
-
#endif // _MSC_VER
|
29
27
|
|
28
|
+
#ifdef _MSC_VER
|
30
29
|
#define __PRETTY_FUNCTION__ __FUNCSIG__
|
30
|
+
#endif // _MSC_VER
|
31
31
|
|
32
32
|
#define posix_memalign(p, a, s) \
|
33
33
|
(((*(p)) = _aligned_malloc((s), (a))), *(p) ? 0 : errno)
|
@@ -37,6 +37,7 @@
|
|
37
37
|
#define ALIGNED(x) __declspec(align(x))
|
38
38
|
|
39
39
|
// redefine the GCC intrinsics with Windows equivalents
|
40
|
+
#ifdef _MSC_VER
|
40
41
|
|
41
42
|
#include <intrin.h>
|
42
43
|
#include <limits.h>
|
@@ -75,6 +76,7 @@ inline int __builtin_clzll(uint64_t x) {
|
|
75
76
|
|
76
77
|
#define __builtin_popcount __popcnt
|
77
78
|
#define __builtin_popcountl __popcnt64
|
79
|
+
#define __builtin_popcountll __popcnt64
|
78
80
|
|
79
81
|
#ifndef __clang__
|
80
82
|
#define __m128i_u __m128i
|
@@ -101,6 +103,8 @@ inline int __builtin_clzll(uint64_t x) {
|
|
101
103
|
#define __F16C__ 1
|
102
104
|
#endif
|
103
105
|
|
106
|
+
#endif // _MSC_VER
|
107
|
+
|
104
108
|
#define FAISS_ALWAYS_INLINE __forceinline
|
105
109
|
|
106
110
|
#else
|
@@ -576,7 +576,7 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
|
|
576
576
|
normalizers = norms;
|
577
577
|
for (int q = 0; q < nq; ++q) {
|
578
578
|
thresholds[q] =
|
579
|
-
normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
|
579
|
+
int(normalizers[2 * q] * (radius - normalizers[2 * q + 1]));
|
580
580
|
}
|
581
581
|
}
|
582
582
|
|
@@ -0,0 +1,67 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <faiss/impl/zerocopy_io.h>
|
9
|
+
#include <cstring>
|
10
|
+
|
11
|
+
namespace faiss {
|
12
|
+
|
13
|
+
ZeroCopyIOReader::ZeroCopyIOReader(uint8_t* data, size_t size)
|
14
|
+
: data_(data), rp_(0), total_(size) {}
|
15
|
+
|
16
|
+
ZeroCopyIOReader::~ZeroCopyIOReader() {}
|
17
|
+
|
18
|
+
size_t ZeroCopyIOReader::get_data_view(void** ptr, size_t size, size_t nitems) {
|
19
|
+
if (size == 0) {
|
20
|
+
return nitems;
|
21
|
+
}
|
22
|
+
|
23
|
+
size_t actual_size = size * nitems;
|
24
|
+
if (rp_ + size * nitems > total_) {
|
25
|
+
actual_size = total_ - rp_;
|
26
|
+
}
|
27
|
+
|
28
|
+
size_t actual_nitems = (actual_size + size - 1) / size;
|
29
|
+
if (actual_nitems == 0) {
|
30
|
+
return 0;
|
31
|
+
}
|
32
|
+
|
33
|
+
// get an address
|
34
|
+
*ptr = (void*)(reinterpret_cast<const char*>(data_ + rp_));
|
35
|
+
|
36
|
+
// alter pos
|
37
|
+
rp_ += size * actual_nitems;
|
38
|
+
|
39
|
+
return actual_nitems;
|
40
|
+
}
|
41
|
+
|
42
|
+
void ZeroCopyIOReader::reset() {
|
43
|
+
rp_ = 0;
|
44
|
+
}
|
45
|
+
|
46
|
+
size_t ZeroCopyIOReader::operator()(void* ptr, size_t size, size_t nitems) {
|
47
|
+
if (size * nitems == 0) {
|
48
|
+
return 0;
|
49
|
+
}
|
50
|
+
|
51
|
+
if (rp_ >= total_) {
|
52
|
+
return 0;
|
53
|
+
}
|
54
|
+
size_t nremain = (total_ - rp_) / size;
|
55
|
+
if (nremain < nitems) {
|
56
|
+
nitems = nremain;
|
57
|
+
}
|
58
|
+
memcpy(ptr, (data_ + rp_), size * nitems);
|
59
|
+
rp_ += size * nitems;
|
60
|
+
return nitems;
|
61
|
+
}
|
62
|
+
|
63
|
+
int ZeroCopyIOReader::filedescriptor() {
|
64
|
+
return -1; // Indicating no file descriptor available for memory buffer
|
65
|
+
}
|
66
|
+
|
67
|
+
} // namespace faiss
|
@@ -0,0 +1,32 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <cstdint>
|
11
|
+
|
12
|
+
#include <faiss/impl/io.h>
|
13
|
+
|
14
|
+
namespace faiss {
|
15
|
+
|
16
|
+
// ZeroCopyIOReader just maps the data from a given pointer.
|
17
|
+
struct ZeroCopyIOReader : public faiss::IOReader {
|
18
|
+
uint8_t* data_;
|
19
|
+
size_t rp_ = 0;
|
20
|
+
size_t total_ = 0;
|
21
|
+
|
22
|
+
ZeroCopyIOReader(uint8_t* data, size_t size);
|
23
|
+
~ZeroCopyIOReader();
|
24
|
+
|
25
|
+
void reset();
|
26
|
+
size_t get_data_view(void** ptr, size_t size, size_t nitems);
|
27
|
+
size_t operator()(void* ptr, size_t size, size_t nitems) override;
|
28
|
+
|
29
|
+
int filedescriptor() override;
|
30
|
+
};
|
31
|
+
|
32
|
+
} // namespace faiss
|
@@ -11,9 +11,6 @@
|
|
11
11
|
|
12
12
|
#include <faiss/index_factory.h>
|
13
13
|
|
14
|
-
#include <cinttypes>
|
15
|
-
#include <cmath>
|
16
|
-
|
17
14
|
#include <map>
|
18
15
|
|
19
16
|
#include <regex>
|
@@ -33,6 +30,7 @@
|
|
33
30
|
#include <faiss/IndexIVFPQ.h>
|
34
31
|
#include <faiss/IndexIVFPQFastScan.h>
|
35
32
|
#include <faiss/IndexIVFPQR.h>
|
33
|
+
#include <faiss/IndexIVFRaBitQ.h>
|
36
34
|
#include <faiss/IndexIVFSpectralHash.h>
|
37
35
|
#include <faiss/IndexLSH.h>
|
38
36
|
#include <faiss/IndexLattice.h>
|
@@ -40,6 +38,7 @@
|
|
40
38
|
#include <faiss/IndexPQ.h>
|
41
39
|
#include <faiss/IndexPQFastScan.h>
|
42
40
|
#include <faiss/IndexPreTransform.h>
|
41
|
+
#include <faiss/IndexRaBitQ.h>
|
43
42
|
#include <faiss/IndexRefine.h>
|
44
43
|
#include <faiss/IndexRowwiseMinMax.h>
|
45
44
|
#include <faiss/IndexScalarQuantizer.h>
|
@@ -67,6 +66,7 @@ namespace {
|
|
67
66
|
*/
|
68
67
|
|
69
68
|
// Full-string regex match of `s` against `pat`; capture groups land in `sm`.
// `sm` refers into `s`, so `s` must outlive any use of the match results.
bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
    // @lint-ignore CLANGTIDY
    const std::regex compiled(pat);
    return std::regex_match(s, sm, compiled);
}
|
72
72
|
|
@@ -164,7 +164,7 @@ const std::string aq_norm_pattern =
|
|
164
164
|
// Three 'x'-separated integers, e.g. "2x8x8" (presumably the product additive
// quantizer parameters — TODO confirm against the parser that consumes it).
const std::string paq_def_pattern("([0-9]+)x([0-9]+)x([0-9]+)");
|
165
165
|
|
166
166
|
AdditiveQuantizer::Search_type_t aq_parse_search_type(
|
167
|
-
std::string stok,
|
167
|
+
const std::string& stok,
|
168
168
|
MetricType metric) {
|
169
169
|
if (stok == "") {
|
170
170
|
return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
|
@@ -177,6 +177,7 @@ AdditiveQuantizer::Search_type_t aq_parse_search_type(
|
|
177
177
|
std::vector<size_t> aq_parse_nbits(std::string stok) {
|
178
178
|
std::vector<size_t> nbits;
|
179
179
|
std::smatch sm;
|
180
|
+
// @lint-ignore CLANGTIDY
|
180
181
|
while (std::regex_search(stok, sm, std::regex("[^q]([0-9]+)x([0-9]+)"))) {
|
181
182
|
int M = std::stoi(sm[1].str());
|
182
183
|
int nbit = std::stoi(sm[2].str());
|
@@ -186,6 +187,8 @@ std::vector<size_t> aq_parse_nbits(std::string stok) {
|
|
186
187
|
return nbits;
|
187
188
|
}
|
188
189
|
|
190
|
+
// Factory keyword selecting RaBitQ: matched both for the IVF variant
// (IndexIVFRaBitQ) and for the flat one (IndexRaBitQ).
const std::string rabitq_pattern("(RaBitQ)");
|
191
|
+
|
189
192
|
/***************************************************************
|
190
193
|
* Parse VectorTransform
|
191
194
|
*/
|
@@ -436,6 +439,9 @@ IndexIVF* parse_IndexIVF(
|
|
436
439
|
}
|
437
440
|
return index_ivf;
|
438
441
|
}
|
442
|
+
if (match(rabitq_pattern)) {
|
443
|
+
return new IndexIVFRaBitQ(get_q(), d, nlist, mt);
|
444
|
+
}
|
439
445
|
return nullptr;
|
440
446
|
}
|
441
447
|
|
@@ -657,6 +663,11 @@ Index* parse_other_indexes(
|
|
657
663
|
}
|
658
664
|
}
|
659
665
|
|
666
|
+
// IndexRaBitQ
|
667
|
+
if (match(rabitq_pattern)) {
|
668
|
+
return new IndexRaBitQ(d, metric);
|
669
|
+
}
|
670
|
+
|
660
671
|
return nullptr;
|
661
672
|
}
|
662
673
|
|
@@ -766,7 +777,7 @@ std::unique_ptr<Index> index_factory_sub(
|
|
766
777
|
}
|
767
778
|
|
768
779
|
if (verbose) {
|
769
|
-
printf("after () normalization: %s %
|
780
|
+
printf("after () normalization: %s %zd parenthesis indexes d=%d\n",
|
770
781
|
description.c_str(),
|
771
782
|
parenthesis_indexes.size(),
|
772
783
|
d);
|
@@ -62,6 +62,10 @@ const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
|
|
62
62
|
// try to memmap data (useful to load an ArrayInvertedLists as an
|
63
63
|
// OnDiskInvertedLists)
|
64
64
|
const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;
|
65
|
+
// mmap that handles codes for IndexFlatCodes-derived indices and HNSW.
|
66
|
+
// this is a temporary solution, it is expected to be merged with IO_FLAG_MMAP
|
67
|
+
// after OnDiskInvertedLists get properly updated.
|
68
|
+
const int IO_FLAG_MMAP_IFC = 1 << 9;
|
65
69
|
|
66
70
|
Index* read_index(const char* fname, int io_flags = 0);
|
67
71
|
Index* read_index(FILE* f, int io_flags = 0);
|
@@ -181,7 +181,7 @@ size_t InvertedLists::copy_subset_to(
|
|
181
181
|
}
|
182
182
|
|
183
183
|
double InvertedLists::imbalance_factor() const {
|
184
|
-
std::vector<
|
184
|
+
std::vector<int64_t> hist(nlist);
|
185
185
|
|
186
186
|
for (size_t i = 0; i < nlist; i++) {
|
187
187
|
hist[i] = list_size(i);
|
@@ -330,8 +330,8 @@ void ArrayInvertedLists::update_entries(
|
|
330
330
|
}
|
331
331
|
|
332
332
|
void ArrayInvertedLists::permute_invlists(const idx_t* map) {
|
333
|
-
std::vector<
|
334
|
-
std::vector<
|
333
|
+
std::vector<MaybeOwnedVector<uint8_t>> new_codes(nlist);
|
334
|
+
std::vector<MaybeOwnedVector<idx_t>> new_ids(nlist);
|
335
335
|
|
336
336
|
for (size_t i = 0; i < nlist; i++) {
|
337
337
|
size_t o = map[i];
|
@@ -15,9 +15,11 @@
|
|
15
15
|
* the interface.
|
16
16
|
*/
|
17
17
|
|
18
|
-
#include <faiss/MetricType.h>
|
19
18
|
#include <vector>
|
20
19
|
|
20
|
+
#include <faiss/MetricType.h>
|
21
|
+
#include <faiss/impl/maybe_owned_vector.h>
|
22
|
+
|
21
23
|
namespace faiss {
|
22
24
|
|
23
25
|
struct InvertedListsIterator {
|
@@ -241,8 +243,8 @@ struct InvertedLists {
|
|
241
243
|
|
242
244
|
/// simple (default) implementation as an array of inverted lists
|
243
245
|
struct ArrayInvertedLists : InvertedLists {
|
244
|
-
std::vector<
|
245
|
-
std::vector<
|
246
|
+
std::vector<MaybeOwnedVector<uint8_t>> codes; // binary codes, size nlist
|
247
|
+
std::vector<MaybeOwnedVector<idx_t>> ids; ///< Inverted lists for indexes
|
246
248
|
|
247
249
|
ArrayInvertedLists(size_t nlist, size_t code_size);
|
248
250
|
|
@@ -13,9 +13,9 @@
|
|
13
13
|
|
14
14
|
#include <faiss/invlists/BlockInvertedLists.h>
|
15
15
|
|
16
|
-
#ifndef
|
16
|
+
#ifndef _WIN32
|
17
17
|
#include <faiss/invlists/OnDiskInvertedLists.h>
|
18
|
-
#endif // !
|
18
|
+
#endif // !_WIN32
|
19
19
|
|
20
20
|
namespace faiss {
|
21
21
|
|
@@ -33,7 +33,7 @@ namespace {
|
|
33
33
|
/// std::vector that deletes its contents
|
34
34
|
struct IOHookTable : std::vector<InvertedListsIOHook*> {
|
35
35
|
IOHookTable() {
|
36
|
-
#ifndef
|
36
|
+
#ifndef _WIN32
|
37
37
|
push_back(new OnDiskInvertedListsIOHook());
|
38
38
|
#endif
|
39
39
|
push_back(new BlockInvertedListsIOHook());
|
@@ -134,3 +134,27 @@ PyCallbackIDSelector::~PyCallbackIDSelector() {
|
|
134
134
|
PyThreadLock gil;
|
135
135
|
Py_DECREF(callback);
|
136
136
|
}
|
137
|
+
|
138
|
+
/***********************************************************
|
139
|
+
* Callbacks for IVF index sharding
|
140
|
+
***********************************************************/
|
141
|
+
|
142
|
+
// Keep a strong reference to the Python callable for the whole lifetime of
// this object (released in the destructor). Reference counts may only be
// modified while holding the GIL.
PyCallbackShardingFunction::PyCallbackShardingFunction(PyObject* callback)
        : callback(callback) {
    PyThreadLock gil;
    Py_INCREF(this->callback);
}
|
147
|
+
|
148
|
+
// Ask the Python callable which shard element `i` belongs to, calling it as
// callback(i, shard_count) with the GIL held.
// Throws (leaving the Python error set for propagation) if the call raises
// or if its result cannot be converted to a 64-bit integer.
int64_t PyCallbackShardingFunction::operator()(int64_t i, int64_t shard_count) {
    PyThreadLock gil;
    // "L" reads a C `long long` from the varargs; int64_t may be a different
    // type (e.g. `long`), so convert explicitly.
    PyObject* shard_id = PyObject_CallFunction(
            callback, "LL", (long long)i, (long long)shard_count);
    if (shard_id == nullptr) {
        FAISS_THROW_MSG("propagate py error");
    }
    // PyObject_CallFunction returns a new reference: release it once the
    // value has been extracted, otherwise every call leaks the result object.
    const long long result = PyLong_AsLongLong(shard_id);
    Py_DECREF(shard_id);
    if (result == -1 && PyErr_Occurred()) {
        // not an integer, or out of range for long long
        FAISS_THROW_MSG("propagate py error");
    }
    return result;
}
|
156
|
+
|
157
|
+
// Give back the reference taken by the constructor; the GIL must be held
// while the reference count is modified.
PyCallbackShardingFunction::~PyCallbackShardingFunction() {
    PyThreadLock lock;
    Py_DECREF(this->callback);
}
|
@@ -7,6 +7,7 @@
|
|
7
7
|
|
8
8
|
#pragma once
|
9
9
|
|
10
|
+
#include <faiss/IVFlib.h>
|
10
11
|
#include <faiss/impl/IDSelector.h>
|
11
12
|
#include <faiss/impl/io.h>
|
12
13
|
#include <faiss/invlists/InvertedLists.h>
|
@@ -58,3 +59,24 @@ struct PyCallbackIDSelector : faiss::IDSelector {
|
|
58
59
|
|
59
60
|
~PyCallbackIDSelector() override;
|
60
61
|
};
|
62
|
+
|
63
|
+
/***********************************************************
|
64
|
+
* Callbacks for IVF index sharding
|
65
|
+
***********************************************************/
|
66
|
+
|
67
|
+
struct PyCallbackShardingFunction : faiss::ivflib::ShardingFunction {
|
68
|
+
PyObject* callback;
|
69
|
+
|
70
|
+
explicit PyCallbackShardingFunction(PyObject* callback);
|
71
|
+
|
72
|
+
int64_t operator()(int64_t i, int64_t shard_count) override;
|
73
|
+
|
74
|
+
~PyCallbackShardingFunction() override;
|
75
|
+
|
76
|
+
PyCallbackShardingFunction(const PyCallbackShardingFunction&) = delete;
|
77
|
+
PyCallbackShardingFunction(PyCallbackShardingFunction&&) noexcept = default;
|
78
|
+
PyCallbackShardingFunction& operator=(const PyCallbackShardingFunction&) =
|
79
|
+
default;
|
80
|
+
PyCallbackShardingFunction& operator=(PyCallbackShardingFunction&&) =
|
81
|
+
default;
|
82
|
+
};
|
@@ -46,9 +46,11 @@ struct HeapWithBucketsForHamming32<
|
|
46
46
|
// output distances
|
47
47
|
int* const __restrict bh_val,
|
48
48
|
// output indices, each being within [0, n) range
|
49
|
-
int64_t* const __restrict bh_ids
|
49
|
+
int64_t* const __restrict bh_ids,
|
50
|
+
// optional id selector for filtering
|
51
|
+
const IDSelector* sel = nullptr) {
|
50
52
|
// forward a call to bs_addn with 1 beam
|
51
|
-
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
|
53
|
+
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids, sel);
|
52
54
|
}
|
53
55
|
|
54
56
|
static void bs_addn(
|
@@ -66,7 +68,9 @@ struct HeapWithBucketsForHamming32<
|
|
66
68
|
int* const __restrict bh_val,
|
67
69
|
// output indices, each being within [0, n_per_beam * beam_size)
|
68
70
|
// range
|
69
|
-
int64_t* const __restrict bh_ids
|
71
|
+
int64_t* const __restrict bh_ids,
|
72
|
+
// optional id selector for filtering
|
73
|
+
const IDSelector* sel = nullptr) {
|
70
74
|
//
|
71
75
|
using C = CMax<int, int64_t>;
|
72
76
|
|
@@ -95,11 +99,22 @@ struct HeapWithBucketsForHamming32<
|
|
95
99
|
for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
|
96
100
|
for (uint32_t j = 0; j < NBUCKETS_8; j++) {
|
97
101
|
uint32_t hamming_distances[8];
|
102
|
+
uint8_t valid_counter = 0;
|
98
103
|
for (size_t j8 = 0; j8 < 8; j8++) {
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
104
|
+
const uint32_t idx =
|
105
|
+
j8 + j * 8 + ip + n_per_beam * beam_index;
|
106
|
+
if (!sel || sel->is_member(idx)) {
|
107
|
+
hamming_distances[j8] = hc.hamming(
|
108
|
+
binary_vectors + idx * code_size);
|
109
|
+
valid_counter++;
|
110
|
+
} else {
|
111
|
+
hamming_distances[j8] =
|
112
|
+
std::numeric_limits<int32_t>::max();
|
113
|
+
}
|
114
|
+
}
|
115
|
+
|
116
|
+
if (valid_counter == 8) {
|
117
|
+
continue; // Skip if all vectors are filtered out
|
103
118
|
}
|
104
119
|
|
105
120
|
// loop. Compiler should get rid of unneeded ops
|
@@ -157,7 +172,8 @@ struct HeapWithBucketsForHamming32<
|
|
157
172
|
const auto value = min_distances_scalar[j8];
|
158
173
|
const auto index = min_indices_scalar[j8];
|
159
174
|
|
160
|
-
if (
|
175
|
+
if (value < std::numeric_limits<int32_t>::max() &&
|
176
|
+
C::cmp2(bh_val[0], value, bh_ids[0], index)) {
|
161
177
|
heap_replace_top<C>(
|
162
178
|
k, bh_val, bh_ids, value, index);
|
163
179
|
}
|
@@ -168,11 +184,13 @@ struct HeapWithBucketsForHamming32<
|
|
168
184
|
// process leftovers
|
169
185
|
for (uint32_t ip = nb; ip < n_per_beam; ip++) {
|
170
186
|
const auto index = ip + n_per_beam * beam_index;
|
171
|
-
|
172
|
-
|
187
|
+
if (!sel || sel->is_member(index)) {
|
188
|
+
const auto value =
|
189
|
+
hc.hamming(binary_vectors + (index)*code_size);
|
173
190
|
|
174
|
-
|
175
|
-
|
191
|
+
if (C::cmp(bh_val[0], value)) {
|
192
|
+
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
|
193
|
+
}
|
176
194
|
}
|
177
195
|
}
|
178
196
|
}
|
@@ -30,6 +30,7 @@
|
|
30
30
|
|
31
31
|
#include <faiss/impl/AuxIndexStructures.h>
|
32
32
|
#include <faiss/impl/FaissAssert.h>
|
33
|
+
#include <faiss/impl/IDSelector.h>
|
33
34
|
#include <faiss/utils/Heap.h>
|
34
35
|
#include <faiss/utils/approx_topk_hamming/approx_topk_hamming.h>
|
35
36
|
#include <faiss/utils/utils.h>
|
@@ -62,15 +63,15 @@ void hammings(
|
|
62
63
|
const uint64_t* __restrict bs2,
|
63
64
|
size_t n1,
|
64
65
|
size_t n2,
|
65
|
-
size_t
|
66
|
+
size_t nbits,
|
66
67
|
hamdis_t* __restrict dis) {
|
67
68
|
size_t i, j;
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
for (j = 0; j < n2; j
|
73
|
-
|
69
|
+
const size_t nwords = nbits / 64;
|
70
|
+
for (i = 0; i < n1; i++) {
|
71
|
+
const uint64_t* __restrict bs1_ = bs1 + i * nwords;
|
72
|
+
hamdis_t* __restrict dis_ = dis + i * n2;
|
73
|
+
for (j = 0; j < n2; j++)
|
74
|
+
dis_[j] = hamming(bs1_, bs2 + j * nwords, nwords);
|
74
75
|
}
|
75
76
|
}
|
76
77
|
|
@@ -171,7 +172,8 @@ void hammings_knn_hc(
|
|
171
172
|
size_t n2,
|
172
173
|
bool order = true,
|
173
174
|
bool init_heap = true,
|
174
|
-
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK
|
175
|
+
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
|
176
|
+
const faiss::IDSelector* sel = nullptr) {
|
175
177
|
size_t k = ha->k;
|
176
178
|
if (init_heap)
|
177
179
|
ha->heapify();
|
@@ -204,7 +206,7 @@ void hammings_knn_hc(
|
|
204
206
|
NB, \
|
205
207
|
BD, \
|
206
208
|
HammingComputer>:: \
|
207
|
-
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_);
|
209
|
+
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_, sel); \
|
208
210
|
break;
|
209
211
|
|
210
212
|
switch (approx_topk_mode) {
|
@@ -214,6 +216,9 @@ void hammings_knn_hc(
|
|
214
216
|
HANDLE_APPROX(32, 2)
|
215
217
|
default: {
|
216
218
|
for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
|
219
|
+
if (sel && !sel->is_member(j)) {
|
220
|
+
continue;
|
221
|
+
}
|
217
222
|
dis = hc.hamming(bs2_);
|
218
223
|
if (dis < bh_val_[0]) {
|
219
224
|
faiss::maxheap_replace_top<hamdis_t>(
|
@@ -238,7 +243,8 @@ void hammings_knn_mc(
|
|
238
243
|
size_t nb,
|
239
244
|
size_t k,
|
240
245
|
int32_t* __restrict distances,
|
241
|
-
int64_t* __restrict labels
|
246
|
+
int64_t* __restrict labels,
|
247
|
+
const faiss::IDSelector* sel) {
|
242
248
|
const int nBuckets = bytes_per_code * 8 + 1;
|
243
249
|
std::vector<int> all_counters(na * nBuckets, 0);
|
244
250
|
std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);
|
@@ -259,7 +265,9 @@ void hammings_knn_mc(
|
|
259
265
|
#pragma omp parallel for
|
260
266
|
for (int64_t i = 0; i < na; ++i) {
|
261
267
|
for (size_t j = j0; j < j1; ++j) {
|
262
|
-
|
268
|
+
if (!sel || sel->is_member(j)) {
|
269
|
+
cs[i].update_counter(b + j * bytes_per_code, j);
|
270
|
+
}
|
263
271
|
}
|
264
272
|
}
|
265
273
|
}
|
@@ -291,7 +299,8 @@ void hamming_range_search(
|
|
291
299
|
size_t nb,
|
292
300
|
int radius,
|
293
301
|
size_t code_size,
|
294
|
-
RangeSearchResult* res
|
302
|
+
RangeSearchResult* res,
|
303
|
+
const faiss::IDSelector* sel) {
|
295
304
|
#pragma omp parallel
|
296
305
|
{
|
297
306
|
RangeSearchPartialResult pres(res);
|
@@ -303,9 +312,11 @@ void hamming_range_search(
|
|
303
312
|
RangeQueryResult& qres = pres.new_result(i);
|
304
313
|
|
305
314
|
for (size_t j = 0; j < nb; j++) {
|
306
|
-
|
307
|
-
|
308
|
-
|
315
|
+
if (!sel || sel->is_member(j)) {
|
316
|
+
int dis = hc.hamming(yi);
|
317
|
+
if (dis < radius) {
|
318
|
+
qres.add(dis, j);
|
319
|
+
}
|
309
320
|
}
|
310
321
|
yi += code_size;
|
311
322
|
}
|
@@ -489,10 +500,21 @@ void hammings_knn_hc(
|
|
489
500
|
size_t nb,
|
490
501
|
size_t ncodes,
|
491
502
|
int order,
|
492
|
-
ApproxTopK_mode_t approx_topk_mode
|
503
|
+
ApproxTopK_mode_t approx_topk_mode,
|
504
|
+
const faiss::IDSelector* sel) {
|
493
505
|
Run_hammings_knn_hc r;
|
494
506
|
dispatch_HammingComputer(
|
495
|
-
ncodes,
|
507
|
+
ncodes,
|
508
|
+
r,
|
509
|
+
ncodes,
|
510
|
+
ha,
|
511
|
+
a,
|
512
|
+
b,
|
513
|
+
nb,
|
514
|
+
order,
|
515
|
+
true,
|
516
|
+
approx_topk_mode,
|
517
|
+
sel);
|
496
518
|
}
|
497
519
|
|
498
520
|
void hammings_knn_mc(
|
@@ -503,10 +525,11 @@ void hammings_knn_mc(
|
|
503
525
|
size_t k,
|
504
526
|
size_t ncodes,
|
505
527
|
int32_t* __restrict distances,
|
506
|
-
int64_t* __restrict labels
|
528
|
+
int64_t* __restrict labels,
|
529
|
+
const faiss::IDSelector* sel) {
|
507
530
|
Run_hammings_knn_mc r;
|
508
531
|
dispatch_HammingComputer(
|
509
|
-
ncodes, r, ncodes, a, b, na, nb, k, distances, labels);
|
532
|
+
ncodes, r, ncodes, a, b, na, nb, k, distances, labels, sel);
|
510
533
|
}
|
511
534
|
|
512
535
|
void hamming_range_search(
|
@@ -516,10 +539,11 @@ void hamming_range_search(
|
|
516
539
|
size_t nb,
|
517
540
|
int radius,
|
518
541
|
size_t code_size,
|
519
|
-
RangeSearchResult* result
|
542
|
+
RangeSearchResult* result,
|
543
|
+
const faiss::IDSelector* sel) {
|
520
544
|
Run_hamming_range_search r;
|
521
545
|
dispatch_HammingComputer(
|
522
|
-
code_size, r, a, b, na, nb, radius, code_size, result);
|
546
|
+
code_size, r, a, b, na, nb, radius, code_size, result, sel);
|
523
547
|
}
|
524
548
|
|
525
549
|
/* Count number of matches given a max threshold */
|
@@ -27,6 +27,7 @@
|
|
27
27
|
|
28
28
|
#include <stdint.h>
|
29
29
|
|
30
|
+
#include <faiss/impl/IDSelector.h>
|
30
31
|
#include <faiss/impl/platform_macros.h>
|
31
32
|
#include <faiss/utils/Heap.h>
|
32
33
|
|
@@ -135,7 +136,8 @@ void hammings_knn_hc(
|
|
135
136
|
size_t nb,
|
136
137
|
size_t ncodes,
|
137
138
|
int ordered,
|
138
|
-
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK
|
139
|
+
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
|
140
|
+
const faiss::IDSelector* sel = nullptr);
|
139
141
|
|
140
142
|
/* Legacy alias to hammings_knn_hc. */
|
141
143
|
void hammings_knn(
|
@@ -166,7 +168,8 @@ void hammings_knn_mc(
|
|
166
168
|
size_t k,
|
167
169
|
size_t ncodes,
|
168
170
|
int32_t* distances,
|
169
|
-
int64_t* labels
|
171
|
+
int64_t* labels,
|
172
|
+
const faiss::IDSelector* sel = nullptr);
|
170
173
|
|
171
174
|
/** same as hammings_knn except we are doing a range search with radius */
|
172
175
|
void hamming_range_search(
|
@@ -176,7 +179,8 @@ void hamming_range_search(
|
|
176
179
|
size_t nb,
|
177
180
|
int radius,
|
178
181
|
size_t ncodes,
|
179
|
-
RangeSearchResult* result
|
182
|
+
RangeSearchResult* result,
|
183
|
+
const faiss::IDSelector* sel = nullptr);
|
180
184
|
|
181
185
|
/* Counting the number of matches or of cross-matches (without returning them)
|
182
186
|
For use with function that assume pre-allocated memory */
|
@@ -11,7 +11,7 @@
|
|
11
11
|
// AVX512 version
|
12
12
|
// The _mm512_popcnt_epi64 intrinsic is used to accelerate Hamming distance
|
13
13
|
// calculations in HammingComputerDefault and HammingComputer64. This intrinsic
|
14
|
-
// is not available in the default
|
14
|
+
// is not available in the default Faiss avx512 build mode but is only
|
15
15
|
// available in the avx512_spr build mode, which targets Intel(R) Sapphire
|
16
16
|
// Rapids.
|
17
17
|
|