faiss 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +18 -18
- data/README.md +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/Clustering.cpp +318 -53
- data/vendor/faiss/Clustering.h +39 -11
- data/vendor/faiss/DirectMap.cpp +267 -0
- data/vendor/faiss/DirectMap.h +120 -0
- data/vendor/faiss/IVFlib.cpp +24 -4
- data/vendor/faiss/IVFlib.h +4 -0
- data/vendor/faiss/Index.h +5 -24
- data/vendor/faiss/Index2Layer.cpp +0 -1
- data/vendor/faiss/IndexBinary.h +7 -3
- data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
- data/vendor/faiss/IndexBinaryFlat.h +3 -0
- data/vendor/faiss/IndexBinaryHash.cpp +492 -0
- data/vendor/faiss/IndexBinaryHash.h +116 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
- data/vendor/faiss/IndexBinaryIVF.h +14 -4
- data/vendor/faiss/IndexFlat.h +2 -1
- data/vendor/faiss/IndexHNSW.cpp +68 -16
- data/vendor/faiss/IndexHNSW.h +3 -3
- data/vendor/faiss/IndexIVF.cpp +72 -76
- data/vendor/faiss/IndexIVF.h +24 -5
- data/vendor/faiss/IndexIVFFlat.cpp +19 -54
- data/vendor/faiss/IndexIVFFlat.h +1 -11
- data/vendor/faiss/IndexIVFPQ.cpp +49 -26
- data/vendor/faiss/IndexIVFPQ.h +9 -10
- data/vendor/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
- data/vendor/faiss/IndexLSH.h +4 -1
- data/vendor/faiss/IndexPreTransform.cpp +0 -1
- data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
- data/vendor/faiss/InvertedLists.cpp +0 -2
- data/vendor/faiss/MetaIndexes.cpp +0 -1
- data/vendor/faiss/MetricType.h +36 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
- data/vendor/faiss/c_api/Clustering_c.h +11 -5
- data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
- data/vendor/faiss/gpu/GpuDistance.h +93 -0
- data/vendor/faiss/gpu/GpuIndex.h +7 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
- data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
- data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
- data/vendor/faiss/impl/HNSW.cpp +0 -1
- data/vendor/faiss/impl/PolysemousTraining.h +5 -5
- data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
- data/vendor/faiss/impl/ProductQuantizer.h +42 -47
- data/vendor/faiss/impl/index_read.cpp +103 -7
- data/vendor/faiss/impl/index_write.cpp +101 -5
- data/vendor/faiss/impl/io.cpp +111 -1
- data/vendor/faiss/impl/io.h +38 -0
- data/vendor/faiss/index_factory.cpp +0 -1
- data/vendor/faiss/tests/test_merge.cpp +0 -1
- data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
- data/vendor/faiss/utils/distances.cpp +4 -5
- data/vendor/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/utils/hamming.cpp +85 -3
- data/vendor/faiss/utils/hamming.h +20 -0
- data/vendor/faiss/utils/utils.cpp +0 -96
- data/vendor/faiss/utils/utils.h +0 -15
- metadata +11 -3
- data/lib/faiss/ext.bundle +0 -0
@@ -46,8 +46,7 @@ struct IndexBinaryIVF : IndexBinary {
|
|
46
46
|
bool use_heap = true;
|
47
47
|
|
48
48
|
/// map for direct access to the elements. Enables reconstruct().
|
49
|
-
|
50
|
-
std::vector<idx_t> direct_map;
|
49
|
+
DirectMap direct_map;
|
51
50
|
|
52
51
|
IndexBinary *quantizer; ///< quantizer that maps vectors to inverted lists
|
53
52
|
size_t nlist; ///< number of possible key values
|
@@ -110,8 +109,11 @@ struct IndexBinaryIVF : IndexBinary {
|
|
110
109
|
bool store_pairs=false) const;
|
111
110
|
|
112
111
|
/** assign the vectors, then call search_preassign */
|
113
|
-
|
114
|
-
|
112
|
+
void search(idx_t n, const uint8_t *x, idx_t k,
|
113
|
+
int32_t *distances, idx_t *labels) const override;
|
114
|
+
|
115
|
+
void range_search(idx_t n, const uint8_t *x, int radius,
|
116
|
+
RangeSearchResult *result) const override;
|
115
117
|
|
116
118
|
void reconstruct(idx_t key, uint8_t *recons) const override;
|
117
119
|
|
@@ -168,6 +170,8 @@ struct IndexBinaryIVF : IndexBinary {
|
|
168
170
|
*/
|
169
171
|
void make_direct_map(bool new_maintain_direct_map=true);
|
170
172
|
|
173
|
+
void set_direct_map_type (DirectMap::Type type);
|
174
|
+
|
171
175
|
void replace_invlists(InvertedLists *il, bool own=false);
|
172
176
|
};
|
173
177
|
|
@@ -201,6 +205,12 @@ struct BinaryInvertedListScanner {
|
|
201
205
|
int32_t *distances, idx_t *labels,
|
202
206
|
size_t k) const = 0;
|
203
207
|
|
208
|
+
virtual void scan_codes_range (size_t n,
|
209
|
+
const uint8_t *codes,
|
210
|
+
const idx_t *ids,
|
211
|
+
int radius,
|
212
|
+
RangeQueryResult &result) const = 0;
|
213
|
+
|
204
214
|
virtual ~BinaryInvertedListScanner () {}
|
205
215
|
|
206
216
|
};
|
data/vendor/faiss/IndexFlat.h
CHANGED
@@ -19,6 +19,7 @@ namespace faiss {
|
|
19
19
|
|
20
20
|
/** Index that stores the full vectors and performs exhaustive search */
|
21
21
|
struct IndexFlat: Index {
|
22
|
+
|
22
23
|
/// database vectors, size ntotal * d
|
23
24
|
std::vector<float> xb;
|
24
25
|
|
@@ -144,7 +145,7 @@ struct IndexRefineFlat: Index {
|
|
144
145
|
};
|
145
146
|
|
146
147
|
|
147
|
-
/// optimized version for 1D "vectors"
|
148
|
+
/// optimized version for 1D "vectors".
|
148
149
|
struct IndexFlat1D:IndexFlatL2 {
|
149
150
|
bool continuous_update; ///< is the permutation updated continuously?
|
150
151
|
|
data/vendor/faiss/IndexHNSW.cpp
CHANGED
@@ -26,7 +26,6 @@
|
|
26
26
|
#include <stdint.h>
|
27
27
|
|
28
28
|
#ifdef __SSE__
|
29
|
-
#include <immintrin.h>
|
30
29
|
#endif
|
31
30
|
|
32
31
|
#include <faiss/utils/distances.h>
|
@@ -55,7 +54,6 @@ namespace faiss {
|
|
55
54
|
using idx_t = Index::idx_t;
|
56
55
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
57
56
|
using storage_idx_t = HNSW::storage_idx_t;
|
58
|
-
using NodeDistCloser = HNSW::NodeDistCloser;
|
59
57
|
using NodeDistFarther = HNSW::NodeDistFarther;
|
60
58
|
|
61
59
|
HNSWStats hnsw_stats;
|
@@ -67,6 +65,50 @@ HNSWStats hnsw_stats;
|
|
67
65
|
namespace {
|
68
66
|
|
69
67
|
|
68
|
+
/* Wrap the distance computer into one that negates the
|
69
|
+
distances. This makes supporting INNER_PRODUCT search easier */
|
70
|
+
|
71
|
+
struct NegativeDistanceComputer: DistanceComputer {
|
72
|
+
|
73
|
+
/// owned by this
|
74
|
+
DistanceComputer *basedis;
|
75
|
+
|
76
|
+
explicit NegativeDistanceComputer(DistanceComputer *basedis):
|
77
|
+
basedis(basedis)
|
78
|
+
{}
|
79
|
+
|
80
|
+
void set_query(const float *x) override {
|
81
|
+
basedis->set_query(x);
|
82
|
+
}
|
83
|
+
|
84
|
+
/// compute distance of vector i to current query
|
85
|
+
float operator () (idx_t i) override {
|
86
|
+
return -(*basedis)(i);
|
87
|
+
}
|
88
|
+
|
89
|
+
/// compute distance between two stored vectors
|
90
|
+
float symmetric_dis (idx_t i, idx_t j) override {
|
91
|
+
return -basedis->symmetric_dis(i, j);
|
92
|
+
}
|
93
|
+
|
94
|
+
virtual ~NegativeDistanceComputer ()
|
95
|
+
{
|
96
|
+
delete basedis;
|
97
|
+
}
|
98
|
+
|
99
|
+
};
|
100
|
+
|
101
|
+
DistanceComputer *storage_distance_computer(const Index *storage)
|
102
|
+
{
|
103
|
+
if (storage->metric_type == METRIC_INNER_PRODUCT) {
|
104
|
+
return new NegativeDistanceComputer(storage->get_distance_computer());
|
105
|
+
} else {
|
106
|
+
return storage->get_distance_computer();
|
107
|
+
}
|
108
|
+
}
|
109
|
+
|
110
|
+
|
111
|
+
|
70
112
|
void hnsw_add_vertices(IndexHNSW &index_hnsw,
|
71
113
|
size_t n0,
|
72
114
|
size_t n, const float *x,
|
@@ -152,7 +194,7 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
|
|
152
194
|
VisitedTable vt (ntotal);
|
153
195
|
|
154
196
|
DistanceComputer *dis =
|
155
|
-
index_hnsw.storage
|
197
|
+
storage_distance_computer (index_hnsw.storage);
|
156
198
|
ScopeDeleter1<DistanceComputer> del(dis);
|
157
199
|
int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
|
158
200
|
size_t counter = 0;
|
@@ -210,8 +252,8 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
|
|
210
252
|
* IndexHNSW implementation
|
211
253
|
**************************************************************/
|
212
254
|
|
213
|
-
IndexHNSW::IndexHNSW(int d, int M):
|
214
|
-
Index(d,
|
255
|
+
IndexHNSW::IndexHNSW(int d, int M, MetricType metric):
|
256
|
+
Index(d, metric),
|
215
257
|
hnsw(M),
|
216
258
|
own_fields(false),
|
217
259
|
storage(nullptr),
|
@@ -258,7 +300,8 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
|
|
258
300
|
#pragma omp parallel reduction(+ : nreorder)
|
259
301
|
{
|
260
302
|
VisitedTable vt (ntotal);
|
261
|
-
|
303
|
+
|
304
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
262
305
|
ScopeDeleter1<DistanceComputer> del(dis);
|
263
306
|
|
264
307
|
#pragma omp for
|
@@ -290,6 +333,14 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
|
|
290
333
|
}
|
291
334
|
InterruptCallback::check ();
|
292
335
|
}
|
336
|
+
|
337
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
338
|
+
// we need to revert the negated distances
|
339
|
+
for (size_t i = 0; i < k * n; i++) {
|
340
|
+
distances[i] = -distances[i];
|
341
|
+
}
|
342
|
+
}
|
343
|
+
|
293
344
|
hnsw_stats.nreorder += nreorder;
|
294
345
|
}
|
295
346
|
|
@@ -323,7 +374,7 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size)
|
|
323
374
|
{
|
324
375
|
#pragma omp parallel
|
325
376
|
{
|
326
|
-
DistanceComputer *dis = storage
|
377
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
327
378
|
ScopeDeleter1<DistanceComputer> del(dis);
|
328
379
|
|
329
380
|
#pragma omp for
|
@@ -367,7 +418,7 @@ void IndexHNSW::search_level_0(
|
|
367
418
|
storage_idx_t ntotal = hnsw.levels.size();
|
368
419
|
#pragma omp parallel
|
369
420
|
{
|
370
|
-
DistanceComputer *qdis = storage
|
421
|
+
DistanceComputer *qdis = storage_distance_computer(storage);
|
371
422
|
ScopeDeleter1<DistanceComputer> del(qdis);
|
372
423
|
|
373
424
|
VisitedTable vt (ntotal);
|
@@ -436,7 +487,7 @@ void IndexHNSW::init_level_0_from_knngraph(
|
|
436
487
|
|
437
488
|
#pragma omp parallel for
|
438
489
|
for (idx_t i = 0; i < ntotal; i++) {
|
439
|
-
DistanceComputer *qdis = storage
|
490
|
+
DistanceComputer *qdis = storage_distance_computer(storage);
|
440
491
|
float vec[d];
|
441
492
|
storage->reconstruct(i, vec);
|
442
493
|
qdis->set_query(vec);
|
@@ -480,7 +531,7 @@ void IndexHNSW::init_level_0_from_entry_points(
|
|
480
531
|
{
|
481
532
|
VisitedTable vt (ntotal);
|
482
533
|
|
483
|
-
DistanceComputer *dis = storage
|
534
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
484
535
|
ScopeDeleter1<DistanceComputer> del(dis);
|
485
536
|
float vec[storage->d];
|
486
537
|
|
@@ -518,7 +569,7 @@ void IndexHNSW::reorder_links()
|
|
518
569
|
std::vector<float> distances (M);
|
519
570
|
std::vector<size_t> order (M);
|
520
571
|
std::vector<storage_idx_t> tmp (M);
|
521
|
-
DistanceComputer *dis = storage
|
572
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
522
573
|
ScopeDeleter1<DistanceComputer> del(dis);
|
523
574
|
|
524
575
|
#pragma omp for
|
@@ -826,8 +877,8 @@ IndexHNSWFlat::IndexHNSWFlat()
|
|
826
877
|
is_trained = true;
|
827
878
|
}
|
828
879
|
|
829
|
-
IndexHNSWFlat::IndexHNSWFlat(int d, int M):
|
830
|
-
IndexHNSW(new
|
880
|
+
IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric):
|
881
|
+
IndexHNSW(new IndexFlat(d, metric), M)
|
831
882
|
{
|
832
883
|
own_fields = true;
|
833
884
|
is_trained = true;
|
@@ -860,8 +911,9 @@ void IndexHNSWPQ::train(idx_t n, const float* x)
|
|
860
911
|
**************************************************************/
|
861
912
|
|
862
913
|
|
863
|
-
IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M
|
864
|
-
|
914
|
+
IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M,
|
915
|
+
MetricType metric):
|
916
|
+
IndexHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
|
865
917
|
{
|
866
918
|
is_trained = false;
|
867
919
|
own_fields = true;
|
@@ -986,7 +1038,7 @@ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
|
|
986
1038
|
#pragma omp parallel
|
987
1039
|
{
|
988
1040
|
VisitedTable vt (ntotal);
|
989
|
-
DistanceComputer *dis = storage
|
1041
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
990
1042
|
ScopeDeleter1<DistanceComputer> del(dis);
|
991
1043
|
|
992
1044
|
int candidates_size = hnsw.upper_beam;
|
data/vendor/faiss/IndexHNSW.h
CHANGED
@@ -79,7 +79,7 @@ struct IndexHNSW : Index {
|
|
79
79
|
|
80
80
|
ReconstructFromNeighbors *reconstruct_from_neighbors;
|
81
81
|
|
82
|
-
explicit IndexHNSW (int d = 0, int M = 32);
|
82
|
+
explicit IndexHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
|
83
83
|
explicit IndexHNSW (Index *storage, int M = 32);
|
84
84
|
|
85
85
|
~IndexHNSW() override;
|
@@ -132,7 +132,7 @@ struct IndexHNSW : Index {
|
|
132
132
|
|
133
133
|
struct IndexHNSWFlat : IndexHNSW {
|
134
134
|
IndexHNSWFlat();
|
135
|
-
IndexHNSWFlat(int d, int M);
|
135
|
+
IndexHNSWFlat(int d, int M, MetricType metric = METRIC_L2);
|
136
136
|
};
|
137
137
|
|
138
138
|
/** PQ index topped with a HNSW structure to access elements
|
@@ -149,7 +149,7 @@ struct IndexHNSWPQ : IndexHNSW {
|
|
149
149
|
*/
|
150
150
|
struct IndexHNSWSQ : IndexHNSW {
|
151
151
|
IndexHNSWSQ();
|
152
|
-
IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M);
|
152
|
+
IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M, MetricType metric = METRIC_L2);
|
153
153
|
};
|
154
154
|
|
155
155
|
/** 2-level code structure with fast random access
|
data/vendor/faiss/IndexIVF.cpp
CHANGED
@@ -157,8 +157,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
|
|
157
157
|
code_size (code_size),
|
158
158
|
nprobe (1),
|
159
159
|
max_codes (0),
|
160
|
-
parallel_mode (0)
|
161
|
-
maintain_direct_map (false)
|
160
|
+
parallel_mode (0)
|
162
161
|
{
|
163
162
|
FAISS_THROW_IF_NOT (d == quantizer->d);
|
164
163
|
is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
|
@@ -172,8 +171,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
|
|
172
171
|
IndexIVF::IndexIVF ():
|
173
172
|
invlists (nullptr), own_invlists (false),
|
174
173
|
code_size (0),
|
175
|
-
nprobe (1), max_codes (0), parallel_mode (0)
|
176
|
-
maintain_direct_map (false)
|
174
|
+
nprobe (1), max_codes (0), parallel_mode (0)
|
177
175
|
{}
|
178
176
|
|
179
177
|
void IndexIVF::add (idx_t n, const float * x)
|
@@ -199,6 +197,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
199
197
|
}
|
200
198
|
|
201
199
|
FAISS_THROW_IF_NOT (is_trained);
|
200
|
+
direct_map.check_can_add (xids);
|
201
|
+
|
202
202
|
std::unique_ptr<idx_t []> idx(new idx_t[n]);
|
203
203
|
quantizer->assign (n, x, idx.get());
|
204
204
|
size_t nadd = 0, nminus1 = 0;
|
@@ -210,6 +210,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
210
210
|
std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
|
211
211
|
encode_vectors (n, x, idx.get(), flat_codes.get());
|
212
212
|
|
213
|
+
DirectMapAdd dm_adder(direct_map, n, xids);
|
214
|
+
|
213
215
|
#pragma omp parallel reduction(+: nadd)
|
214
216
|
{
|
215
217
|
int nt = omp_get_num_threads();
|
@@ -220,13 +222,21 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
220
222
|
idx_t list_no = idx [i];
|
221
223
|
if (list_no >= 0 && list_no % nt == rank) {
|
222
224
|
idx_t id = xids ? xids[i] : ntotal + i;
|
223
|
-
invlists->add_entry (
|
224
|
-
|
225
|
+
size_t ofs = invlists->add_entry (
|
226
|
+
list_no, id,
|
227
|
+
flat_codes.get() + i * code_size
|
228
|
+
);
|
229
|
+
|
230
|
+
dm_adder.add (i, list_no, ofs);
|
231
|
+
|
225
232
|
nadd++;
|
233
|
+
} else if (rank == 0 && list_no == -1) {
|
234
|
+
dm_adder.add (i, -1, 0);
|
226
235
|
}
|
227
236
|
}
|
228
237
|
}
|
229
238
|
|
239
|
+
|
230
240
|
if (verbose) {
|
231
241
|
printf(" added %ld / %ld vectors (%ld -1s)\n", nadd, n, nminus1);
|
232
242
|
}
|
@@ -234,30 +244,18 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
234
244
|
ntotal += n;
|
235
245
|
}
|
236
246
|
|
237
|
-
|
238
|
-
void IndexIVF::make_direct_map (bool new_maintain_direct_map)
|
247
|
+
void IndexIVF::make_direct_map (bool b)
|
239
248
|
{
|
240
|
-
|
241
|
-
|
242
|
-
return;
|
243
|
-
|
244
|
-
if (new_maintain_direct_map) {
|
245
|
-
direct_map.resize (ntotal, -1);
|
246
|
-
for (size_t key = 0; key < nlist; key++) {
|
247
|
-
size_t list_size = invlists->list_size (key);
|
248
|
-
ScopedIds idlist (invlists, key);
|
249
|
-
|
250
|
-
for (long ofs = 0; ofs < list_size; ofs++) {
|
251
|
-
FAISS_THROW_IF_NOT_MSG (
|
252
|
-
0 <= idlist [ofs] && idlist[ofs] < ntotal,
|
253
|
-
"direct map supported only for seuquential ids");
|
254
|
-
direct_map [idlist [ofs]] = key << 32 | ofs;
|
255
|
-
}
|
256
|
-
}
|
249
|
+
if (b) {
|
250
|
+
direct_map.set_type (DirectMap::Array, invlists, ntotal);
|
257
251
|
} else {
|
258
|
-
direct_map.
|
252
|
+
direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
|
259
253
|
}
|
260
|
-
|
254
|
+
}
|
255
|
+
|
256
|
+
void IndexIVF::set_direct_map_type (DirectMap::Type type)
|
257
|
+
{
|
258
|
+
direct_map.set_type (type, invlists, ntotal);
|
261
259
|
}
|
262
260
|
|
263
261
|
|
@@ -298,10 +296,13 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
298
296
|
|
299
297
|
bool interrupt = false;
|
300
298
|
|
299
|
+
int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
|
300
|
+
bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
|
301
|
+
|
301
302
|
// don't start parallel section if single query
|
302
303
|
bool do_parallel =
|
303
|
-
|
304
|
-
|
304
|
+
pmode == 0 ? n > 1 :
|
305
|
+
pmode == 1 ? nprobe > 1 :
|
305
306
|
nprobe * n > 1;
|
306
307
|
|
307
308
|
#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
|
@@ -318,6 +319,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
318
319
|
// initialize + reorder a result heap
|
319
320
|
|
320
321
|
auto init_result = [&](float *simi, idx_t *idxi) {
|
322
|
+
if (!do_heap_init) return;
|
321
323
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
322
324
|
heap_heapify<HeapForIP> (k, simi, idxi);
|
323
325
|
} else {
|
@@ -326,6 +328,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
326
328
|
};
|
327
329
|
|
328
330
|
auto reorder_result = [&] (float *simi, idx_t *idxi) {
|
331
|
+
if (!do_heap_init) return;
|
329
332
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
330
333
|
heap_reorder<HeapForIP> (k, simi, idxi);
|
331
334
|
} else {
|
@@ -377,7 +380,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
377
380
|
* Actual loops, depending on parallel_mode
|
378
381
|
****************************************************/
|
379
382
|
|
380
|
-
if (
|
383
|
+
if (pmode == 0) {
|
381
384
|
|
382
385
|
#pragma omp for
|
383
386
|
for (size_t i = 0; i < n; i++) {
|
@@ -417,7 +420,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
417
420
|
}
|
418
421
|
|
419
422
|
} // parallel for
|
420
|
-
} else if (
|
423
|
+
} else if (pmode == 1) {
|
421
424
|
std::vector <idx_t> local_idx (k);
|
422
425
|
std::vector <float> local_dis (k);
|
423
426
|
|
@@ -460,7 +463,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
460
463
|
}
|
461
464
|
} else {
|
462
465
|
FAISS_THROW_FMT ("parallel_mode %d not supported\n",
|
463
|
-
|
466
|
+
pmode);
|
464
467
|
}
|
465
468
|
} // parallel section
|
466
469
|
|
@@ -608,13 +611,8 @@ InvertedListScanner *IndexIVF::get_InvertedListScanner (
|
|
608
611
|
|
609
612
|
void IndexIVF::reconstruct (idx_t key, float* recons) const
|
610
613
|
{
|
611
|
-
|
612
|
-
|
613
|
-
FAISS_THROW_IF_NOT_MSG (key >= 0 && key < direct_map.size(),
|
614
|
-
"invalid key");
|
615
|
-
idx_t list_no = direct_map[key] >> 32;
|
616
|
-
idx_t offset = direct_map[key] & 0xffffffff;
|
617
|
-
reconstruct_from_offset (list_no, offset, recons);
|
614
|
+
idx_t lo = direct_map.get (key);
|
615
|
+
reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
|
618
616
|
}
|
619
617
|
|
620
618
|
|
@@ -682,8 +680,8 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
|
|
682
680
|
// Fill with NaNs
|
683
681
|
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
684
682
|
} else {
|
685
|
-
int list_no = key
|
686
|
-
int offset = key
|
683
|
+
int list_no = lo_listno (key);
|
684
|
+
int offset = lo_offset (key);
|
687
685
|
|
688
686
|
// Update label to the actual id
|
689
687
|
labels[ij] = invlists->get_single_id (list_no, offset);
|
@@ -711,42 +709,41 @@ void IndexIVF::reset ()
|
|
711
709
|
|
712
710
|
size_t IndexIVF::remove_ids (const IDSelector & sel)
|
713
711
|
{
|
714
|
-
|
715
|
-
"direct map remove not implemented");
|
716
|
-
|
717
|
-
std::vector<idx_t> toremove(nlist);
|
718
|
-
|
719
|
-
#pragma omp parallel for
|
720
|
-
for (idx_t i = 0; i < nlist; i++) {
|
721
|
-
idx_t l0 = invlists->list_size (i), l = l0, j = 0;
|
722
|
-
ScopedIds idsi (invlists, i);
|
723
|
-
while (j < l) {
|
724
|
-
if (sel.is_member (idsi[j])) {
|
725
|
-
l--;
|
726
|
-
invlists->update_entry (
|
727
|
-
i, j,
|
728
|
-
invlists->get_single_id (i, l),
|
729
|
-
ScopedCodes (invlists, i, l).get());
|
730
|
-
} else {
|
731
|
-
j++;
|
732
|
-
}
|
733
|
-
}
|
734
|
-
toremove[i] = l0 - l;
|
735
|
-
}
|
736
|
-
// this will not run well in parallel on ondisk because of possible shrinks
|
737
|
-
size_t nremove = 0;
|
738
|
-
for (idx_t i = 0; i < nlist; i++) {
|
739
|
-
if (toremove[i] > 0) {
|
740
|
-
nremove += toremove[i];
|
741
|
-
invlists->resize(
|
742
|
-
i, invlists->list_size(i) - toremove[i]);
|
743
|
-
}
|
744
|
-
}
|
712
|
+
size_t nremove = direct_map.remove_ids (sel, invlists);
|
745
713
|
ntotal -= nremove;
|
746
714
|
return nremove;
|
747
715
|
}
|
748
716
|
|
749
717
|
|
718
|
+
void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
|
719
|
+
{
|
720
|
+
|
721
|
+
if (direct_map.type == DirectMap::Hashtable) {
|
722
|
+
// just remove then add
|
723
|
+
IDSelectorArray sel(n, new_ids);
|
724
|
+
size_t nremove = remove_ids (sel);
|
725
|
+
FAISS_THROW_IF_NOT_MSG (nremove == n,
|
726
|
+
"did not find all entries to remove");
|
727
|
+
add_with_ids (n, x, new_ids);
|
728
|
+
return;
|
729
|
+
}
|
730
|
+
|
731
|
+
FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array);
|
732
|
+
// here it is more tricky because we don't want to introduce holes
|
733
|
+
// in continuous range of ids
|
734
|
+
|
735
|
+
FAISS_THROW_IF_NOT (is_trained);
|
736
|
+
std::vector<idx_t> assign (n);
|
737
|
+
quantizer->assign (n, x, assign.data());
|
738
|
+
|
739
|
+
std::vector<uint8_t> flat_codes (n * code_size);
|
740
|
+
encode_vectors (n, x, assign.data(), flat_codes.data());
|
741
|
+
|
742
|
+
direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data());
|
743
|
+
|
744
|
+
}
|
745
|
+
|
746
|
+
|
750
747
|
|
751
748
|
|
752
749
|
void IndexIVF::train (idx_t n, const float *x)
|
@@ -779,15 +776,14 @@ void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
|
|
779
776
|
FAISS_THROW_IF_NOT (other.code_size == code_size);
|
780
777
|
FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
|
781
778
|
"can only merge indexes of the same type");
|
779
|
+
FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(),
|
780
|
+
"merge direct_map not implemented");
|
782
781
|
}
|
783
782
|
|
784
783
|
|
785
784
|
void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
|
786
785
|
{
|
787
786
|
check_compatible_for_merge (other);
|
788
|
-
FAISS_THROW_IF_NOT_MSG ((!maintain_direct_map &&
|
789
|
-
!other.maintain_direct_map),
|
790
|
-
"direct map copy not implemented");
|
791
787
|
|
792
788
|
invlists->merge_from (other.invlists, add_id);
|
793
789
|
|
@@ -817,7 +813,7 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
|
|
817
813
|
|
818
814
|
FAISS_THROW_IF_NOT (nlist == other.nlist);
|
819
815
|
FAISS_THROW_IF_NOT (code_size == other.code_size);
|
820
|
-
FAISS_THROW_IF_NOT (
|
816
|
+
FAISS_THROW_IF_NOT (other.direct_map.no());
|
821
817
|
FAISS_THROW_IF_NOT_FMT (
|
822
818
|
subset_type == 0 || subset_type == 1 || subset_type == 2,
|
823
819
|
"subset type %d not implemented", subset_type);
|