faiss 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +18 -18
- data/README.md +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/Clustering.cpp +318 -53
- data/vendor/faiss/Clustering.h +39 -11
- data/vendor/faiss/DirectMap.cpp +267 -0
- data/vendor/faiss/DirectMap.h +120 -0
- data/vendor/faiss/IVFlib.cpp +24 -4
- data/vendor/faiss/IVFlib.h +4 -0
- data/vendor/faiss/Index.h +5 -24
- data/vendor/faiss/Index2Layer.cpp +0 -1
- data/vendor/faiss/IndexBinary.h +7 -3
- data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
- data/vendor/faiss/IndexBinaryFlat.h +3 -0
- data/vendor/faiss/IndexBinaryHash.cpp +492 -0
- data/vendor/faiss/IndexBinaryHash.h +116 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
- data/vendor/faiss/IndexBinaryIVF.h +14 -4
- data/vendor/faiss/IndexFlat.h +2 -1
- data/vendor/faiss/IndexHNSW.cpp +68 -16
- data/vendor/faiss/IndexHNSW.h +3 -3
- data/vendor/faiss/IndexIVF.cpp +72 -76
- data/vendor/faiss/IndexIVF.h +24 -5
- data/vendor/faiss/IndexIVFFlat.cpp +19 -54
- data/vendor/faiss/IndexIVFFlat.h +1 -11
- data/vendor/faiss/IndexIVFPQ.cpp +49 -26
- data/vendor/faiss/IndexIVFPQ.h +9 -10
- data/vendor/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
- data/vendor/faiss/IndexLSH.h +4 -1
- data/vendor/faiss/IndexPreTransform.cpp +0 -1
- data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
- data/vendor/faiss/InvertedLists.cpp +0 -2
- data/vendor/faiss/MetaIndexes.cpp +0 -1
- data/vendor/faiss/MetricType.h +36 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
- data/vendor/faiss/c_api/Clustering_c.h +11 -5
- data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
- data/vendor/faiss/gpu/GpuDistance.h +93 -0
- data/vendor/faiss/gpu/GpuIndex.h +7 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
- data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
- data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
- data/vendor/faiss/impl/HNSW.cpp +0 -1
- data/vendor/faiss/impl/PolysemousTraining.h +5 -5
- data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
- data/vendor/faiss/impl/ProductQuantizer.h +42 -47
- data/vendor/faiss/impl/index_read.cpp +103 -7
- data/vendor/faiss/impl/index_write.cpp +101 -5
- data/vendor/faiss/impl/io.cpp +111 -1
- data/vendor/faiss/impl/io.h +38 -0
- data/vendor/faiss/index_factory.cpp +0 -1
- data/vendor/faiss/tests/test_merge.cpp +0 -1
- data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
- data/vendor/faiss/utils/distances.cpp +4 -5
- data/vendor/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/utils/hamming.cpp +85 -3
- data/vendor/faiss/utils/hamming.h +20 -0
- data/vendor/faiss/utils/utils.cpp +0 -96
- data/vendor/faiss/utils/utils.h +0 -15
- metadata +11 -3
- data/lib/faiss/ext.bundle +0 -0
data/vendor/faiss/IVFlib.cpp
CHANGED
@@ -13,6 +13,7 @@
|
|
13
13
|
|
14
14
|
#include <faiss/IndexPreTransform.h>
|
15
15
|
#include <faiss/impl/FaissAssert.h>
|
16
|
+
#include <faiss/MetaIndexes.h>
|
16
17
|
|
17
18
|
|
18
19
|
|
@@ -56,17 +57,35 @@ void check_compatible_for_merge (const Index * index0,
|
|
56
57
|
|
57
58
|
}
|
58
59
|
|
59
|
-
const IndexIVF *
|
60
|
+
const IndexIVF * try_extract_index_ivf (const Index * index)
|
60
61
|
{
|
61
62
|
if (auto *pt =
|
62
63
|
dynamic_cast<const IndexPreTransform *>(index)) {
|
63
64
|
index = pt->index;
|
64
65
|
}
|
65
66
|
|
67
|
+
if (auto *idmap =
|
68
|
+
dynamic_cast<const IndexIDMap *>(index)) {
|
69
|
+
index = idmap->index;
|
70
|
+
}
|
71
|
+
if (auto *idmap =
|
72
|
+
dynamic_cast<const IndexIDMap2 *>(index)) {
|
73
|
+
index = idmap->index;
|
74
|
+
}
|
75
|
+
|
66
76
|
auto *ivf = dynamic_cast<const IndexIVF *>(index);
|
67
77
|
|
68
|
-
|
78
|
+
return ivf;
|
79
|
+
}
|
69
80
|
|
81
|
+
/// Non-const overload of try_extract_index_ivf: performs the same lookup as
/// the const version and hands back a mutable pointer (nullptr on failure).
IndexIVF * try_extract_index_ivf (Index * index) {
    const Index *const_index = index;   // implicit, safe conversion to const
    const IndexIVF *found = try_extract_index_ivf (const_index);
    return const_cast<IndexIVF*> (found);
}
|
84
|
+
|
85
|
+
const IndexIVF * extract_index_ivf (const Index * index)
|
86
|
+
{
|
87
|
+
const IndexIVF *ivf = try_extract_index_ivf (index);
|
88
|
+
FAISS_THROW_IF_NOT (ivf);
|
70
89
|
return ivf;
|
71
90
|
}
|
72
91
|
|
@@ -74,6 +93,7 @@ IndexIVF * extract_index_ivf (Index * index) {
|
|
74
93
|
return const_cast<IndexIVF*> (extract_index_ivf ((const Index*)(index)));
|
75
94
|
}
|
76
95
|
|
96
|
+
|
77
97
|
void merge_into(faiss::Index *index0, faiss::Index *index1, bool shift_ids) {
|
78
98
|
|
79
99
|
check_compatible_for_merge (index0, index1);
|
@@ -146,8 +166,8 @@ void search_and_return_centroids(faiss::Index *index,
|
|
146
166
|
if (result_centroid_ids)
|
147
167
|
result_centroid_ids[i] = -1;
|
148
168
|
} else {
|
149
|
-
long list_no = label
|
150
|
-
long list_index = label
|
169
|
+
long list_no = lo_listno (label);
|
170
|
+
long list_index = lo_offset (label);
|
151
171
|
if (result_centroid_ids)
|
152
172
|
result_centroid_ids[i] = list_no;
|
153
173
|
labels[i] = index_ivf->invlists->get_single_id(list_no, list_index);
|
data/vendor/faiss/IVFlib.h
CHANGED
@@ -35,6 +35,10 @@ void check_compatible_for_merge (const Index * index1,
|
|
35
35
|
const IndexIVF * extract_index_ivf (const Index * index);
|
36
36
|
IndexIVF * extract_index_ivf (Index * index);
|
37
37
|
|
38
|
+
/// same as above but returns nullptr instead of throwing on failure
|
39
|
+
const IndexIVF * try_extract_index_ivf (const Index * index);
|
40
|
+
IndexIVF * try_extract_index_ivf (Index * index);
|
41
|
+
|
38
42
|
/** Merge index1 into index0. Works on IndexIVF's and IndexIVF's
|
39
43
|
* embedded in a IndexPreTransform. On output, the index1 is empty.
|
40
44
|
*
|
data/vendor/faiss/Index.h
CHANGED
@@ -10,7 +10,7 @@
|
|
10
10
|
#ifndef FAISS_INDEX_H
|
11
11
|
#define FAISS_INDEX_H
|
12
12
|
|
13
|
-
|
13
|
+
#include <faiss/MetricType.h>
|
14
14
|
#include <cstdio>
|
15
15
|
#include <typeinfo>
|
16
16
|
#include <string>
|
@@ -18,7 +18,7 @@
|
|
18
18
|
|
19
19
|
#define FAISS_VERSION_MAJOR 1
|
20
20
|
#define FAISS_VERSION_MINOR 6
|
21
|
-
#define FAISS_VERSION_PATCH
|
21
|
+
#define FAISS_VERSION_PATCH 3
|
22
22
|
|
23
23
|
/**
|
24
24
|
* @namespace faiss
|
@@ -39,34 +39,15 @@
|
|
39
39
|
|
40
40
|
namespace faiss {
|
41
41
|
|
42
|
-
|
43
|
-
/// Some algorithms support both an inner product version and a L2 search version.
|
44
|
-
enum MetricType {
|
45
|
-
METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
|
46
|
-
METRIC_L2 = 1, ///< squared L2 search
|
47
|
-
METRIC_L1, ///< L1 (aka cityblock)
|
48
|
-
METRIC_Linf, ///< infinity distance
|
49
|
-
METRIC_Lp, ///< L_p distance, p is given by metric_arg
|
50
|
-
|
51
|
-
/// some additional metrics defined in scipy.spatial.distance
|
52
|
-
METRIC_Canberra = 20,
|
53
|
-
METRIC_BrayCurtis,
|
54
|
-
METRIC_JensenShannon,
|
55
|
-
|
56
|
-
};
|
57
|
-
|
58
|
-
|
59
42
|
/// Forward declarations see AuxIndexStructures.h
|
60
43
|
struct IDSelector;
|
61
44
|
struct RangeSearchResult;
|
62
45
|
struct DistanceComputer;
|
63
46
|
|
64
|
-
/** Abstract structure for an index
|
65
|
-
*
|
66
|
-
* Supports adding vertices and searching them.
|
47
|
+
/** Abstract structure for an index, supports adding vectors and searching them.
|
67
48
|
*
|
68
|
-
*
|
69
|
-
*
|
49
|
+
* All vectors provided at add or search time are 32-bit float arrays,
|
50
|
+
* although the internal representation may vary.
|
70
51
|
*/
|
71
52
|
struct Index {
|
72
53
|
using idx_t = int64_t; ///< all indices are this type
|
data/vendor/faiss/IndexBinary.h
CHANGED
@@ -99,9 +99,13 @@ struct IndexBinary {
|
|
99
99
|
|
100
100
|
/** Query n vectors of dimension d to the index.
|
101
101
|
*
|
102
|
-
* return all vectors with distance < radius. Note that many
|
103
|
-
*
|
104
|
-
*
|
102
|
+
* return all vectors with distance < radius. Note that many indexes
|
103
|
+
* do not implement the range_search (only the k-NN search is
|
104
|
+
* mandatory). The distances are converted to float to reuse the
|
105
|
+
* RangeSearchResult structure, but they are integer. By convention,
|
106
|
+
* only distances < radius (strict comparison) are returned,
|
107
|
+
* ie. radius = 0 does not return any result and 1 returns only
|
108
|
+
* exact same vectors.
|
105
109
|
*
|
106
110
|
* @param x input vectors to search, size n * d / 8
|
107
111
|
* @param radius search radius
|
@@ -79,5 +79,10 @@ void IndexBinaryFlat::reconstruct(idx_t key, uint8_t *recons) const {
|
|
79
79
|
memcpy(recons, &(xb[code_size * key]), sizeof(*recons) * code_size);
|
80
80
|
}
|
81
81
|
|
82
|
+
/* Range search over the flat code array: delegates to hamming_range_search
 * with the n queries in x against all ntotal stored codes (code_size bytes
 * each) and writes the hits into `result`. The radius threshold semantics
 * are those of hamming_range_search. */
void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
                                   RangeSearchResult *result) const
{
    const uint8_t *database = xb.data();
    hamming_range_search (x, database, n, ntotal, radius, code_size, result);
}
|
82
87
|
|
83
88
|
} // namespace faiss
|
@@ -38,6 +38,9 @@ struct IndexBinaryFlat : IndexBinary {
|
|
38
38
|
void search(idx_t n, const uint8_t *x, idx_t k,
|
39
39
|
int32_t *distances, idx_t *labels) const override;
|
40
40
|
|
41
|
+
void range_search(idx_t n, const uint8_t *x, int radius,
|
42
|
+
RangeSearchResult *result) const override;
|
43
|
+
|
41
44
|
void reconstruct(idx_t key, uint8_t *recons) const override;
|
42
45
|
|
43
46
|
/** Remove some ids. Note that because of the indexing structure,
|
@@ -0,0 +1,492 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// Copyright 2004-present Facebook. All Rights Reserved
|
9
|
+
// -*- c++ -*-
|
10
|
+
|
11
|
+
#include <faiss/IndexBinaryHash.h>

#include <cstdio>
#include <cstring>
#include <memory>
#include <unordered_set>

#include <faiss/utils/hamming.h>
#include <faiss/utils/utils.h>

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
|
21
|
+
|
22
|
+
|
23
|
+
namespace faiss {
|
24
|
+
|
25
|
+
/* Append one entry to this bucket: the id goes to `ids` and the code_size
 * bytes of `code` are appended to the flat `vecs` buffer (so, as long as
 * every call uses the same code_size, entry i starts at i * code_size). */
void IndexBinaryHash::InvertedList::add (
        idx_t id, size_t code_size, const uint8_t *code)
{
    const uint8_t *code_end = code + code_size;
    ids.push_back(id);
    vecs.insert(vecs.end(), code, code_end);
}
|
31
|
+
|
32
|
+
/* Hash index over d-bit binary codes, bucketed on b bits of each code.
 * Buckets come straight from the code bits, so no training is needed. */
IndexBinaryHash::IndexBinaryHash(int d, int b)
    : IndexBinary(d), b(b), nflip(0)
{
    is_trained = true;
}
|
37
|
+
|
38
|
+
/* Default constructor: zero hash bits, no flips, marked as trained. */
IndexBinaryHash::IndexBinaryHash()
    : b(0), nflip(0)
{
    is_trained = true;
}
|
42
|
+
|
43
|
+
void IndexBinaryHash::reset()
|
44
|
+
{
|
45
|
+
invlists.clear();
|
46
|
+
ntotal = 0;
|
47
|
+
}
|
48
|
+
|
49
|
+
|
50
|
+
/* Add n codes; ids are assigned sequentially starting at ntotal. */
void IndexBinaryHash::add(idx_t n, const uint8_t *x)
{
    const idx_t *sequential_ids = nullptr; // nullptr => ntotal + i
    add_with_ids(n, x, sequential_ids);
}
|
54
|
+
|
55
|
+
/* Insert n codes into the hash table.
 *
 * The bucket key of a code is the low-order b bits of its leading 64-bit
 * word (read in host byte order). When xids is nullptr, vector i receives
 * id ntotal + i.
 *
 * The leading word is copied with memcpy, limited to code_size bytes: the
 * former direct uint64_t load always read 8 bytes, overrunning the end of x
 * for the last vector when code_size < 8, and was an unaligned type-punned
 * access. */
void IndexBinaryHash::add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids)
{
    // b == 64 would make (1 << b) undefined; use an all-ones mask instead
    const uint64_t mask = b >= 64 ? ~uint64_t(0) : ((uint64_t)1 << b) - 1;
    const size_t nbytes =
        (size_t)code_size < sizeof(uint64_t) ? (size_t)code_size : sizeof(uint64_t);

    // simplistic add function. Cannot really be parallelized.
    for (idx_t i = 0; i < n; i++) {
        idx_t id = xids ? xids[i] : ntotal + i;
        const uint8_t * xi = x + i * code_size;
        uint64_t word = 0;
        memcpy(&word, xi, nbytes);
        idx_t hash = word & mask;
        invlists[hash].add(id, code_size, xi);
    }
    ntotal += n;
}
|
68
|
+
|
69
|
+
namespace {
|
70
|
+
|
71
|
+
|
72
|
+
/** Enumerate all bit vectors of size nbit with up to maxflip 1s.
 *
 * The current vector is `x`; it starts at 0 (no bit flipped). Each call to
 * next() advances to the next vector and returns false once every vector
 * with at most min(maxflip, nbit) set bits has been produced. Vectors come
 * out by increasing number of set bits (Hamming radius). Usage:
 *
 *     FlipEnumerator fe(nbit, maxflip);
 *     do { use(fe.x); } while (fe.next());
 *
 * nbit is expected to be in [0, 64]. The radius is capped at nbit: without
 * that cap, a maxflip larger than nbit made `mask << (nbit - nflip)` shift
 * by a negative amount (undefined behavior).
 */
struct FlipEnumerator {
    int nbit, nflip, maxflip;
    uint64_t mask, x;

    FlipEnumerator (int nbit, int maxflip): nbit(nbit), maxflip(maxflip) {
        nflip = 0;
        mask = 0;
        x = 0;
    }

    bool next() {
        if (x == mask) {
            // all vectors of the current radius have been produced
            if (nflip >= maxflip || nflip >= nbit) {
                return false;
            }
            // increase Hamming radius
            nflip++;
            // nflip lowest bits set (guard the shift for nflip == 64)
            mask = nflip >= 64 ? ~uint64_t(0) : (((uint64_t)1 << nflip) - 1);
            // first vector of this radius: the nflip highest bits set
            x = mask << (nbit - nflip);
            return true;
        }

        int i = __builtin_ctzll(x);

        if (i > 0) {
            // the lowest set bit can move down by one position
            x ^= (uint64_t)3 << (i - 1);
        } else {
            // nb of LSB 1s
            int n1 = __builtin_ctzll(~x);
            // clear them
            x &= ((uint64_t)(-1) << n1);
            // move the next set bit down by one and put the n1 cleared
            // bits right below it (bit count is preserved)
            int n2 = __builtin_ctzll(x);
            x ^= (((uint64_t)1 << (n1 + 2)) - 1) << (n2 - n1 - 1);
        }
        return true;
    }

};
|
113
|
+
|
114
|
+
// Local shorthand for the index id type (int64_t per Index::idx_t),
// used by the result collectors and search helpers below.
using idx_t = Index::idx_t;
|
115
|
+
|
116
|
+
|
117
|
+
struct RangeSearchResults {
|
118
|
+
int radius;
|
119
|
+
RangeQueryResult &qres;
|
120
|
+
|
121
|
+
inline void add (float dis, idx_t id) {
|
122
|
+
if (dis < radius) {
|
123
|
+
qres.add (dis, id);
|
124
|
+
}
|
125
|
+
}
|
126
|
+
|
127
|
+
};
|
128
|
+
|
129
|
+
struct KnnSearchResults {
|
130
|
+
// heap params
|
131
|
+
idx_t k;
|
132
|
+
int32_t * heap_sim;
|
133
|
+
idx_t * heap_ids;
|
134
|
+
|
135
|
+
using C = CMax<int, idx_t>;
|
136
|
+
|
137
|
+
inline void add (float dis, idx_t id) {
|
138
|
+
if (dis < heap_sim[0]) {
|
139
|
+
heap_pop<C> (k, heap_sim, heap_ids);
|
140
|
+
heap_push<C> (k, heap_sim, heap_ids, dis, id);
|
141
|
+
}
|
142
|
+
}
|
143
|
+
|
144
|
+
};
|
145
|
+
|
146
|
+
template<class HammingComputer, class SearchResults>
|
147
|
+
void
|
148
|
+
search_single_query_template(const IndexBinaryHash & index, const uint8_t *q,
|
149
|
+
SearchResults &res,
|
150
|
+
size_t &n0, size_t &nlist, size_t &ndis)
|
151
|
+
{
|
152
|
+
size_t code_size = index.code_size;
|
153
|
+
uint64_t mask = ((uint64_t)1 << index.b) - 1;
|
154
|
+
uint64_t qhash = *((uint64_t*)q) & mask;
|
155
|
+
HammingComputer hc (q, code_size);
|
156
|
+
FlipEnumerator fe(index.b, index.nflip);
|
157
|
+
|
158
|
+
// loop over neighbors that are at most at nflip bits
|
159
|
+
do {
|
160
|
+
uint64_t hash = qhash ^ fe.x;
|
161
|
+
auto it = index.invlists.find (hash);
|
162
|
+
|
163
|
+
if (it == index.invlists.end()) {
|
164
|
+
continue;
|
165
|
+
}
|
166
|
+
|
167
|
+
const IndexBinaryHash::InvertedList &il = it->second;
|
168
|
+
|
169
|
+
size_t nv = il.ids.size();
|
170
|
+
|
171
|
+
if (nv == 0) {
|
172
|
+
n0++;
|
173
|
+
} else {
|
174
|
+
const uint8_t *codes = il.vecs.data();
|
175
|
+
for (size_t i = 0; i < nv; i++) {
|
176
|
+
int dis = hc.hamming (codes);
|
177
|
+
res.add(dis, il.ids[i]);
|
178
|
+
codes += code_size;
|
179
|
+
}
|
180
|
+
ndis += nv;
|
181
|
+
nlist++;
|
182
|
+
}
|
183
|
+
} while(fe.next());
|
184
|
+
}
|
185
|
+
|
186
|
+
template<class SearchResults>
|
187
|
+
void
|
188
|
+
search_single_query(const IndexBinaryHash & index, const uint8_t *q,
|
189
|
+
SearchResults &res,
|
190
|
+
size_t &n0, size_t &nlist, size_t &ndis)
|
191
|
+
{
|
192
|
+
#define HC(name) search_single_query_template<name>(index, q, res, n0, nlist, ndis);
|
193
|
+
switch(index.code_size) {
|
194
|
+
case 4: HC(HammingComputer4); break;
|
195
|
+
case 8: HC(HammingComputer8); break;
|
196
|
+
case 16: HC(HammingComputer16); break;
|
197
|
+
case 20: HC(HammingComputer20); break;
|
198
|
+
case 32: HC(HammingComputer32); break;
|
199
|
+
default:
|
200
|
+
if (index.code_size % 8 == 0) {
|
201
|
+
HC(HammingComputerM8);
|
202
|
+
} else {
|
203
|
+
HC(HammingComputerDefault);
|
204
|
+
}
|
205
|
+
}
|
206
|
+
#undef HC
|
207
|
+
}
|
208
|
+
|
209
|
+
|
210
|
+
} // anonymous namespace
|
211
|
+
|
212
|
+
|
213
|
+
|
214
|
+
/* Range search: for each of the n queries, collect every stored code from
 * the probed buckets whose Hamming distance is strictly below `radius`.
 * Queries run in parallel when n > 100; the probe counters are reduced over
 * threads and then added to indexBinaryHash_stats. */
void IndexBinaryHash::range_search(idx_t n, const uint8_t *x, int radius,
                                   RangeSearchResult *result) const
{
    size_t n0 = 0, nlist = 0, ndis = 0;

#pragma omp parallel if(n > 100) reduction(+: ndis, n0, nlist)
    {
        // per-thread partial result, flushed by finalize() below
        RangeSearchPartialResult pres (result);

#pragma omp for
        for (size_t qi = 0; qi < n; qi++) {
            RangeSearchResults collector = {radius, pres.new_result (qi)};
            const uint8_t *query = x + qi * code_size;
            search_single_query (*this, query, collector, n0, nlist, ndis);
        }
        pres.finalize ();
    }

    indexBinaryHash_stats.nq += n;
    indexBinaryHash_stats.n0 += n0;
    indexBinaryHash_stats.nlist += nlist;
    indexBinaryHash_stats.ndis += ndis;
}
|
240
|
+
|
241
|
+
/* k-NN search: for each query, keep the k stored codes with the smallest
 * Hamming distance among the probed buckets and return them sorted by
 * increasing distance. Slots not filled by any hit keep the placeholder
 * values written by heap_heapify.
 *
 * The heap is now reordered after collection; previously the results were
 * returned in raw heap order, i.e. not sorted by distance. */
void IndexBinaryHash::search(idx_t n, const uint8_t *x, idx_t k,
                             int32_t *distances, idx_t *labels) const
{
    using HeapForL2 = CMax<int32_t, idx_t>;
    size_t nlist = 0, ndis = 0, n0 = 0;

#pragma omp parallel for if(n > 100) reduction(+: nlist, ndis, n0)
    for (size_t i = 0; i < n; i++) {
        int32_t * simi = distances + k * i;
        idx_t * idxi = labels + k * i;

        heap_heapify<HeapForL2> (k, simi, idxi);
        KnnSearchResults res = {k, simi, idxi};
        const uint8_t *q = x + i * code_size;

        search_single_query (*this, q, res, n0, nlist, ndis);

        // turn the max-heap into an array sorted by increasing distance
        heap_reorder<HeapForL2> (k, simi, idxi);
    }
    indexBinaryHash_stats.nq += n;
    indexBinaryHash_stats.n0 += n0;
    indexBinaryHash_stats.nlist += nlist;
    indexBinaryHash_stats.ndis += ndis;
}
|
265
|
+
|
266
|
+
size_t IndexBinaryHash::hashtable_size() const
|
267
|
+
{
|
268
|
+
return invlists.size();
|
269
|
+
}
|
270
|
+
|
271
|
+
|
272
|
+
void IndexBinaryHash::display() const
|
273
|
+
{
|
274
|
+
for (auto it = invlists.begin(); it != invlists.end(); ++it) {
|
275
|
+
printf("%ld: [", it->first);
|
276
|
+
const std::vector<idx_t> & v = it->second.ids;
|
277
|
+
for (auto x: v) {
|
278
|
+
printf("%ld ", 0 + x);
|
279
|
+
}
|
280
|
+
printf("]\n");
|
281
|
+
|
282
|
+
}
|
283
|
+
}
|
284
|
+
|
285
|
+
|
286
|
+
void IndexBinaryHashStats::reset()
|
287
|
+
{
|
288
|
+
memset ((void*)this, 0, sizeof (*this));
|
289
|
+
}
|
290
|
+
|
291
|
+
// Global search counters, accumulated by the search / range_search methods
// of IndexBinaryHash and IndexBinaryMultiHash in this file. The updates are
// plain (non-atomic) additions done after each parallel region.
IndexBinaryHashStats indexBinaryHash_stats;
|
292
|
+
|
293
|
+
/*******************************************************
|
294
|
+
* IndexBinaryMultiHash implementation
|
295
|
+
******************************************************/
|
296
|
+
|
297
|
+
|
298
|
+
/* Multi-table hash index: nhash tables, table h keyed on the b bits that
 * start at bit h*b of each code. The codes are kept in an owned
 * IndexBinaryFlat used to verify candidates.
 *
 * Throws if the nhash tables need more than d bits. The check now runs
 * before `storage` is allocated: previously the IndexBinaryFlat was created
 * in the initializer list and leaked when the check threw, because the
 * destructor does not run for an object whose constructor throws. */
IndexBinaryMultiHash::IndexBinaryMultiHash(int d, int nhash, int b):
    IndexBinary(d),
    storage(nullptr), own_fields(true),
    maps(nhash), nhash(nhash), b(b), nflip(0)
{
    FAISS_THROW_IF_NOT(nhash * b <= d);
    storage = new IndexBinaryFlat(d);
}
|
305
|
+
|
306
|
+
/* Empty multi-hash index with no storage attached. own_fields is true, so a
 * storage assigned later is deleted by the destructor. */
IndexBinaryMultiHash::IndexBinaryMultiHash()
    : storage(nullptr),
      own_fields(true),
      nhash(0),
      b(0),
      nflip(0)
{
}
|
310
|
+
|
311
|
+
/* Release the verification storage, but only when this index owns it. */
IndexBinaryMultiHash::~IndexBinaryMultiHash()
{
    if (!own_fields) {
        return;
    }
    delete storage;
}
|
317
|
+
|
318
|
+
|
319
|
+
void IndexBinaryMultiHash::reset()
|
320
|
+
{
|
321
|
+
storage->reset();
|
322
|
+
ntotal = 0;
|
323
|
+
for(auto map: maps) {
|
324
|
+
map.clear();
|
325
|
+
}
|
326
|
+
}
|
327
|
+
|
328
|
+
/* Add n codes: store them in `storage`, then register id ntotal + i in every
 * table h under the key made of the b bits starting at bit h*b of the code.
 *
 * Key bytes are copied with memcpy, limited to the bytes remaining in the
 * code. The former 8-byte load at byte offset (h*b)/8 could read past the
 * end of the code (and of x, for the last vector) when table h's bits lie
 * near the end of the code. */
void IndexBinaryMultiHash::add(idx_t n, const uint8_t *x)
{
    storage->add(n, x);
    // populate maps
    // b == 64 would make (1 << b) undefined; use an all-ones mask instead
    uint64_t mask = b >= 64 ? ~uint64_t(0) : ((uint64_t)1 << b) - 1;

    for(idx_t i = 0; i < n; i++) {
        const uint8_t *xi = x + i * code_size;
        int ho = 0; // bit offset of the current table's key
        for(int h = 0; h < nhash; h++) {
            size_t byte_offset = ho >> 3;
            size_t avail = (size_t)code_size - byte_offset;
            uint64_t word = 0;
            memcpy(&word, xi + byte_offset,
                   avail < sizeof(word) ? avail : sizeof(word));
            uint64_t hash = (word >> (ho & 7)) & mask;
            maps[h][hash].push_back(i + ntotal);
            ho += b;
        }
    }
    ntotal += n;
}
|
346
|
+
|
347
|
+
|
348
|
+
namespace {
|
349
|
+
|
350
|
+
template <class HammingComputer, class SearchResults>
|
351
|
+
static
|
352
|
+
void verify_shortlist(
|
353
|
+
const IndexBinaryFlat & index,
|
354
|
+
const uint8_t * q,
|
355
|
+
const std::unordered_set<Index::idx_t> & shortlist,
|
356
|
+
SearchResults &res)
|
357
|
+
{
|
358
|
+
size_t code_size = index.code_size;
|
359
|
+
size_t nlist = 0, ndis = 0, n0 = 0;
|
360
|
+
|
361
|
+
HammingComputer hc (q, code_size);
|
362
|
+
const uint8_t *codes = index.xb.data();
|
363
|
+
|
364
|
+
for (auto i: shortlist) {
|
365
|
+
int dis = hc.hamming (codes + i * code_size);
|
366
|
+
res.add(dis, i);
|
367
|
+
}
|
368
|
+
}
|
369
|
+
|
370
|
+
template<class SearchResults>
|
371
|
+
void
|
372
|
+
search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi,
|
373
|
+
SearchResults &res,
|
374
|
+
size_t &n0, size_t &nlist, size_t &ndis)
|
375
|
+
{
|
376
|
+
|
377
|
+
std::unordered_set<idx_t> shortlist;
|
378
|
+
int b = index.b;
|
379
|
+
uint64_t mask = ((uint64_t)1 << b) - 1;
|
380
|
+
|
381
|
+
int ho = 0;
|
382
|
+
for(int h = 0; h < index.nhash; h++) {
|
383
|
+
uint64_t qhash = *(uint64_t*)(xi + (ho >> 3)) >> (ho & 7);
|
384
|
+
qhash &= mask;
|
385
|
+
const IndexBinaryMultiHash::Map & map = index.maps[h];
|
386
|
+
|
387
|
+
FlipEnumerator fe(index.b, index.nflip);
|
388
|
+
// loop over neighbors that are at most at nflip bits
|
389
|
+
do {
|
390
|
+
uint64_t hash = qhash ^ fe.x;
|
391
|
+
auto it = map.find (hash);
|
392
|
+
|
393
|
+
if (it != map.end()) {
|
394
|
+
const std::vector<idx_t> & v = it->second;
|
395
|
+
for (auto i: v) {
|
396
|
+
shortlist.insert(i);
|
397
|
+
}
|
398
|
+
nlist++;
|
399
|
+
} else {
|
400
|
+
n0++;
|
401
|
+
}
|
402
|
+
} while(fe.next());
|
403
|
+
|
404
|
+
ho += b;
|
405
|
+
}
|
406
|
+
ndis += shortlist.size();
|
407
|
+
|
408
|
+
// verify shortlist
|
409
|
+
|
410
|
+
#define HC(name) verify_shortlist<name> (*index.storage, xi, shortlist, res)
|
411
|
+
switch(index.code_size) {
|
412
|
+
case 4: HC(HammingComputer4); break;
|
413
|
+
case 8: HC(HammingComputer8); break;
|
414
|
+
case 16: HC(HammingComputer16); break;
|
415
|
+
case 20: HC(HammingComputer20); break;
|
416
|
+
case 32: HC(HammingComputer32); break;
|
417
|
+
default:
|
418
|
+
if (index.code_size % 8 == 0) {
|
419
|
+
HC(HammingComputerM8);
|
420
|
+
} else {
|
421
|
+
HC(HammingComputerDefault);
|
422
|
+
}
|
423
|
+
}
|
424
|
+
#undef HC
|
425
|
+
}
|
426
|
+
|
427
|
+
} // anonymous namespace
|
428
|
+
|
429
|
+
/* Range search: for each of the n queries, verify the multi-hash shortlist
 * and keep codes whose Hamming distance is strictly below `radius`.
 * Queries run in parallel when n > 100; the probe counters are reduced over
 * threads and then added to indexBinaryHash_stats. */
void IndexBinaryMultiHash::range_search(idx_t n, const uint8_t *x, int radius,
                                        RangeSearchResult *result) const
{
    size_t n0 = 0, nlist = 0, ndis = 0;

#pragma omp parallel if(n > 100) reduction(+: ndis, n0, nlist)
    {
        // per-thread partial result, flushed by finalize() below
        RangeSearchPartialResult pres (result);

#pragma omp for
        for (size_t qi = 0; qi < n; qi++) {
            RangeSearchResults collector = {radius, pres.new_result (qi)};
            const uint8_t *query = x + qi * code_size;
            search_1_query_multihash (*this, query, collector, n0, nlist, ndis);
        }
        pres.finalize ();
    }

    indexBinaryHash_stats.nq += n;
    indexBinaryHash_stats.n0 += n0;
    indexBinaryHash_stats.nlist += nlist;
    indexBinaryHash_stats.ndis += ndis;
}
|
455
|
+
|
456
|
+
/* k-NN search: for each query, keep the k shortlisted codes with the
 * smallest Hamming distance and return them sorted by increasing distance.
 * Slots not filled by any hit keep the placeholder values written by
 * heap_heapify.
 *
 * The heap is now reordered after collection; previously the results were
 * returned in raw heap order, i.e. not sorted by distance. */
void IndexBinaryMultiHash::search(idx_t n, const uint8_t *x, idx_t k,
                                  int32_t *distances, idx_t *labels) const
{
    using HeapForL2 = CMax<int32_t, idx_t>;
    size_t nlist = 0, ndis = 0, n0 = 0;

#pragma omp parallel for if(n > 100) reduction(+: nlist, ndis, n0)
    for (size_t i = 0; i < n; i++) {
        int32_t * simi = distances + k * i;
        idx_t * idxi = labels + k * i;

        heap_heapify<HeapForL2> (k, simi, idxi);
        KnnSearchResults res = {k, simi, idxi};
        const uint8_t *q = x + i * code_size;

        search_1_query_multihash (*this, q, res, n0, nlist, ndis);

        // turn the max-heap into an array sorted by increasing distance
        heap_reorder<HeapForL2> (k, simi, idxi);
    }
    indexBinaryHash_stats.nq += n;
    indexBinaryHash_stats.n0 += n0;
    indexBinaryHash_stats.nlist += nlist;
    indexBinaryHash_stats.ndis += ndis;
}
|
480
|
+
|
481
|
+
size_t IndexBinaryMultiHash::hashtable_size() const
|
482
|
+
{
|
483
|
+
size_t tot = 0;
|
484
|
+
for (auto map: maps) {
|
485
|
+
tot += map.size();
|
486
|
+
}
|
487
|
+
|
488
|
+
return tot;
|
489
|
+
}
|
490
|
+
|
491
|
+
|
492
|
+
}
|