faiss 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +18 -18
- data/README.md +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/Clustering.cpp +318 -53
- data/vendor/faiss/Clustering.h +39 -11
- data/vendor/faiss/DirectMap.cpp +267 -0
- data/vendor/faiss/DirectMap.h +120 -0
- data/vendor/faiss/IVFlib.cpp +24 -4
- data/vendor/faiss/IVFlib.h +4 -0
- data/vendor/faiss/Index.h +5 -24
- data/vendor/faiss/Index2Layer.cpp +0 -1
- data/vendor/faiss/IndexBinary.h +7 -3
- data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
- data/vendor/faiss/IndexBinaryFlat.h +3 -0
- data/vendor/faiss/IndexBinaryHash.cpp +492 -0
- data/vendor/faiss/IndexBinaryHash.h +116 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
- data/vendor/faiss/IndexBinaryIVF.h +14 -4
- data/vendor/faiss/IndexFlat.h +2 -1
- data/vendor/faiss/IndexHNSW.cpp +68 -16
- data/vendor/faiss/IndexHNSW.h +3 -3
- data/vendor/faiss/IndexIVF.cpp +72 -76
- data/vendor/faiss/IndexIVF.h +24 -5
- data/vendor/faiss/IndexIVFFlat.cpp +19 -54
- data/vendor/faiss/IndexIVFFlat.h +1 -11
- data/vendor/faiss/IndexIVFPQ.cpp +49 -26
- data/vendor/faiss/IndexIVFPQ.h +9 -10
- data/vendor/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
- data/vendor/faiss/IndexLSH.h +4 -1
- data/vendor/faiss/IndexPreTransform.cpp +0 -1
- data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
- data/vendor/faiss/InvertedLists.cpp +0 -2
- data/vendor/faiss/MetaIndexes.cpp +0 -1
- data/vendor/faiss/MetricType.h +36 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
- data/vendor/faiss/c_api/Clustering_c.h +11 -5
- data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
- data/vendor/faiss/gpu/GpuDistance.h +93 -0
- data/vendor/faiss/gpu/GpuIndex.h +7 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
- data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
- data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
- data/vendor/faiss/impl/HNSW.cpp +0 -1
- data/vendor/faiss/impl/PolysemousTraining.h +5 -5
- data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
- data/vendor/faiss/impl/ProductQuantizer.h +42 -47
- data/vendor/faiss/impl/index_read.cpp +103 -7
- data/vendor/faiss/impl/index_write.cpp +101 -5
- data/vendor/faiss/impl/io.cpp +111 -1
- data/vendor/faiss/impl/io.h +38 -0
- data/vendor/faiss/index_factory.cpp +0 -1
- data/vendor/faiss/tests/test_merge.cpp +0 -1
- data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
- data/vendor/faiss/utils/distances.cpp +4 -5
- data/vendor/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/utils/hamming.cpp +85 -3
- data/vendor/faiss/utils/hamming.h +20 -0
- data/vendor/faiss/utils/utils.cpp +0 -96
- data/vendor/faiss/utils/utils.h +0 -15
- metadata +11 -3
- data/lib/faiss/ext.bundle +0 -0
@@ -46,8 +46,7 @@ struct IndexBinaryIVF : IndexBinary {
|
|
46
46
|
bool use_heap = true;
|
47
47
|
|
48
48
|
/// map for direct access to the elements. Enables reconstruct().
|
49
|
-
|
50
|
-
std::vector<idx_t> direct_map;
|
49
|
+
DirectMap direct_map;
|
51
50
|
|
52
51
|
IndexBinary *quantizer; ///< quantizer that maps vectors to inverted lists
|
53
52
|
size_t nlist; ///< number of possible key values
|
@@ -110,8 +109,11 @@ struct IndexBinaryIVF : IndexBinary {
|
|
110
109
|
bool store_pairs=false) const;
|
111
110
|
|
112
111
|
/** assign the vectors, then call search_preassign */
|
113
|
-
|
114
|
-
|
112
|
+
void search(idx_t n, const uint8_t *x, idx_t k,
|
113
|
+
int32_t *distances, idx_t *labels) const override;
|
114
|
+
|
115
|
+
void range_search(idx_t n, const uint8_t *x, int radius,
|
116
|
+
RangeSearchResult *result) const override;
|
115
117
|
|
116
118
|
void reconstruct(idx_t key, uint8_t *recons) const override;
|
117
119
|
|
@@ -168,6 +170,8 @@ struct IndexBinaryIVF : IndexBinary {
|
|
168
170
|
*/
|
169
171
|
void make_direct_map(bool new_maintain_direct_map=true);
|
170
172
|
|
173
|
+
void set_direct_map_type (DirectMap::Type type);
|
174
|
+
|
171
175
|
void replace_invlists(InvertedLists *il, bool own=false);
|
172
176
|
};
|
173
177
|
|
@@ -201,6 +205,12 @@ struct BinaryInvertedListScanner {
|
|
201
205
|
int32_t *distances, idx_t *labels,
|
202
206
|
size_t k) const = 0;
|
203
207
|
|
208
|
+
virtual void scan_codes_range (size_t n,
|
209
|
+
const uint8_t *codes,
|
210
|
+
const idx_t *ids,
|
211
|
+
int radius,
|
212
|
+
RangeQueryResult &result) const = 0;
|
213
|
+
|
204
214
|
virtual ~BinaryInvertedListScanner () {}
|
205
215
|
|
206
216
|
};
|
data/vendor/faiss/IndexFlat.h
CHANGED
@@ -19,6 +19,7 @@ namespace faiss {
|
|
19
19
|
|
20
20
|
/** Index that stores the full vectors and performs exhaustive search */
|
21
21
|
struct IndexFlat: Index {
|
22
|
+
|
22
23
|
/// database vectors, size ntotal * d
|
23
24
|
std::vector<float> xb;
|
24
25
|
|
@@ -144,7 +145,7 @@ struct IndexRefineFlat: Index {
|
|
144
145
|
};
|
145
146
|
|
146
147
|
|
147
|
-
/// optimized version for 1D "vectors"
|
148
|
+
/// optimized version for 1D "vectors".
|
148
149
|
struct IndexFlat1D:IndexFlatL2 {
|
149
150
|
bool continuous_update; ///< is the permutation updated continuously?
|
150
151
|
|
data/vendor/faiss/IndexHNSW.cpp
CHANGED
@@ -26,7 +26,6 @@
|
|
26
26
|
#include <stdint.h>
|
27
27
|
|
28
28
|
#ifdef __SSE__
|
29
|
-
#include <immintrin.h>
|
30
29
|
#endif
|
31
30
|
|
32
31
|
#include <faiss/utils/distances.h>
|
@@ -55,7 +54,6 @@ namespace faiss {
|
|
55
54
|
using idx_t = Index::idx_t;
|
56
55
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
57
56
|
using storage_idx_t = HNSW::storage_idx_t;
|
58
|
-
using NodeDistCloser = HNSW::NodeDistCloser;
|
59
57
|
using NodeDistFarther = HNSW::NodeDistFarther;
|
60
58
|
|
61
59
|
HNSWStats hnsw_stats;
|
@@ -67,6 +65,50 @@ HNSWStats hnsw_stats;
|
|
67
65
|
namespace {
|
68
66
|
|
69
67
|
|
68
|
+
/* Wrap the distance computer into one that negates the
|
69
|
+
distances. This makes supporting INNER_PRODUCT search easier */
|
70
|
+
|
71
|
+
struct NegativeDistanceComputer: DistanceComputer {
|
72
|
+
|
73
|
+
/// owned by this
|
74
|
+
DistanceComputer *basedis;
|
75
|
+
|
76
|
+
explicit NegativeDistanceComputer(DistanceComputer *basedis):
|
77
|
+
basedis(basedis)
|
78
|
+
{}
|
79
|
+
|
80
|
+
void set_query(const float *x) override {
|
81
|
+
basedis->set_query(x);
|
82
|
+
}
|
83
|
+
|
84
|
+
/// compute distance of vector i to current query
|
85
|
+
float operator () (idx_t i) override {
|
86
|
+
return -(*basedis)(i);
|
87
|
+
}
|
88
|
+
|
89
|
+
/// compute distance between two stored vectors
|
90
|
+
float symmetric_dis (idx_t i, idx_t j) override {
|
91
|
+
return -basedis->symmetric_dis(i, j);
|
92
|
+
}
|
93
|
+
|
94
|
+
virtual ~NegativeDistanceComputer ()
|
95
|
+
{
|
96
|
+
delete basedis;
|
97
|
+
}
|
98
|
+
|
99
|
+
};
|
100
|
+
|
101
|
+
DistanceComputer *storage_distance_computer(const Index *storage)
|
102
|
+
{
|
103
|
+
if (storage->metric_type == METRIC_INNER_PRODUCT) {
|
104
|
+
return new NegativeDistanceComputer(storage->get_distance_computer());
|
105
|
+
} else {
|
106
|
+
return storage->get_distance_computer();
|
107
|
+
}
|
108
|
+
}
|
109
|
+
|
110
|
+
|
111
|
+
|
70
112
|
void hnsw_add_vertices(IndexHNSW &index_hnsw,
|
71
113
|
size_t n0,
|
72
114
|
size_t n, const float *x,
|
@@ -152,7 +194,7 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
|
|
152
194
|
VisitedTable vt (ntotal);
|
153
195
|
|
154
196
|
DistanceComputer *dis =
|
155
|
-
index_hnsw.storage
|
197
|
+
storage_distance_computer (index_hnsw.storage);
|
156
198
|
ScopeDeleter1<DistanceComputer> del(dis);
|
157
199
|
int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
|
158
200
|
size_t counter = 0;
|
@@ -210,8 +252,8 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
|
|
210
252
|
* IndexHNSW implementation
|
211
253
|
**************************************************************/
|
212
254
|
|
213
|
-
IndexHNSW::IndexHNSW(int d, int M):
|
214
|
-
Index(d,
|
255
|
+
IndexHNSW::IndexHNSW(int d, int M, MetricType metric):
|
256
|
+
Index(d, metric),
|
215
257
|
hnsw(M),
|
216
258
|
own_fields(false),
|
217
259
|
storage(nullptr),
|
@@ -258,7 +300,8 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
|
|
258
300
|
#pragma omp parallel reduction(+ : nreorder)
|
259
301
|
{
|
260
302
|
VisitedTable vt (ntotal);
|
261
|
-
|
303
|
+
|
304
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
262
305
|
ScopeDeleter1<DistanceComputer> del(dis);
|
263
306
|
|
264
307
|
#pragma omp for
|
@@ -290,6 +333,14 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
|
|
290
333
|
}
|
291
334
|
InterruptCallback::check ();
|
292
335
|
}
|
336
|
+
|
337
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
338
|
+
// we need to revert the negated distances
|
339
|
+
for (size_t i = 0; i < k * n; i++) {
|
340
|
+
distances[i] = -distances[i];
|
341
|
+
}
|
342
|
+
}
|
343
|
+
|
293
344
|
hnsw_stats.nreorder += nreorder;
|
294
345
|
}
|
295
346
|
|
@@ -323,7 +374,7 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size)
|
|
323
374
|
{
|
324
375
|
#pragma omp parallel
|
325
376
|
{
|
326
|
-
DistanceComputer *dis = storage
|
377
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
327
378
|
ScopeDeleter1<DistanceComputer> del(dis);
|
328
379
|
|
329
380
|
#pragma omp for
|
@@ -367,7 +418,7 @@ void IndexHNSW::search_level_0(
|
|
367
418
|
storage_idx_t ntotal = hnsw.levels.size();
|
368
419
|
#pragma omp parallel
|
369
420
|
{
|
370
|
-
DistanceComputer *qdis = storage
|
421
|
+
DistanceComputer *qdis = storage_distance_computer(storage);
|
371
422
|
ScopeDeleter1<DistanceComputer> del(qdis);
|
372
423
|
|
373
424
|
VisitedTable vt (ntotal);
|
@@ -436,7 +487,7 @@ void IndexHNSW::init_level_0_from_knngraph(
|
|
436
487
|
|
437
488
|
#pragma omp parallel for
|
438
489
|
for (idx_t i = 0; i < ntotal; i++) {
|
439
|
-
DistanceComputer *qdis = storage
|
490
|
+
DistanceComputer *qdis = storage_distance_computer(storage);
|
440
491
|
float vec[d];
|
441
492
|
storage->reconstruct(i, vec);
|
442
493
|
qdis->set_query(vec);
|
@@ -480,7 +531,7 @@ void IndexHNSW::init_level_0_from_entry_points(
|
|
480
531
|
{
|
481
532
|
VisitedTable vt (ntotal);
|
482
533
|
|
483
|
-
DistanceComputer *dis = storage
|
534
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
484
535
|
ScopeDeleter1<DistanceComputer> del(dis);
|
485
536
|
float vec[storage->d];
|
486
537
|
|
@@ -518,7 +569,7 @@ void IndexHNSW::reorder_links()
|
|
518
569
|
std::vector<float> distances (M);
|
519
570
|
std::vector<size_t> order (M);
|
520
571
|
std::vector<storage_idx_t> tmp (M);
|
521
|
-
DistanceComputer *dis = storage
|
572
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
522
573
|
ScopeDeleter1<DistanceComputer> del(dis);
|
523
574
|
|
524
575
|
#pragma omp for
|
@@ -826,8 +877,8 @@ IndexHNSWFlat::IndexHNSWFlat()
|
|
826
877
|
is_trained = true;
|
827
878
|
}
|
828
879
|
|
829
|
-
IndexHNSWFlat::IndexHNSWFlat(int d, int M):
|
830
|
-
IndexHNSW(new
|
880
|
+
IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric):
|
881
|
+
IndexHNSW(new IndexFlat(d, metric), M)
|
831
882
|
{
|
832
883
|
own_fields = true;
|
833
884
|
is_trained = true;
|
@@ -860,8 +911,9 @@ void IndexHNSWPQ::train(idx_t n, const float* x)
|
|
860
911
|
**************************************************************/
|
861
912
|
|
862
913
|
|
863
|
-
IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M
|
864
|
-
|
914
|
+
IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M,
|
915
|
+
MetricType metric):
|
916
|
+
IndexHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
|
865
917
|
{
|
866
918
|
is_trained = false;
|
867
919
|
own_fields = true;
|
@@ -986,7 +1038,7 @@ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
|
|
986
1038
|
#pragma omp parallel
|
987
1039
|
{
|
988
1040
|
VisitedTable vt (ntotal);
|
989
|
-
DistanceComputer *dis = storage
|
1041
|
+
DistanceComputer *dis = storage_distance_computer(storage);
|
990
1042
|
ScopeDeleter1<DistanceComputer> del(dis);
|
991
1043
|
|
992
1044
|
int candidates_size = hnsw.upper_beam;
|
data/vendor/faiss/IndexHNSW.h
CHANGED
@@ -79,7 +79,7 @@ struct IndexHNSW : Index {
|
|
79
79
|
|
80
80
|
ReconstructFromNeighbors *reconstruct_from_neighbors;
|
81
81
|
|
82
|
-
explicit IndexHNSW (int d = 0, int M = 32);
|
82
|
+
explicit IndexHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
|
83
83
|
explicit IndexHNSW (Index *storage, int M = 32);
|
84
84
|
|
85
85
|
~IndexHNSW() override;
|
@@ -132,7 +132,7 @@ struct IndexHNSW : Index {
|
|
132
132
|
|
133
133
|
struct IndexHNSWFlat : IndexHNSW {
|
134
134
|
IndexHNSWFlat();
|
135
|
-
IndexHNSWFlat(int d, int M);
|
135
|
+
IndexHNSWFlat(int d, int M, MetricType metric = METRIC_L2);
|
136
136
|
};
|
137
137
|
|
138
138
|
/** PQ index topped with a HNSW structure to access elements
|
@@ -149,7 +149,7 @@ struct IndexHNSWPQ : IndexHNSW {
|
|
149
149
|
*/
|
150
150
|
struct IndexHNSWSQ : IndexHNSW {
|
151
151
|
IndexHNSWSQ();
|
152
|
-
IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M);
|
152
|
+
IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M, MetricType metric = METRIC_L2);
|
153
153
|
};
|
154
154
|
|
155
155
|
/** 2-level code structure with fast random access
|
data/vendor/faiss/IndexIVF.cpp
CHANGED
@@ -157,8 +157,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
|
|
157
157
|
code_size (code_size),
|
158
158
|
nprobe (1),
|
159
159
|
max_codes (0),
|
160
|
-
parallel_mode (0)
|
161
|
-
maintain_direct_map (false)
|
160
|
+
parallel_mode (0)
|
162
161
|
{
|
163
162
|
FAISS_THROW_IF_NOT (d == quantizer->d);
|
164
163
|
is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
|
@@ -172,8 +171,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
|
|
172
171
|
IndexIVF::IndexIVF ():
|
173
172
|
invlists (nullptr), own_invlists (false),
|
174
173
|
code_size (0),
|
175
|
-
nprobe (1), max_codes (0), parallel_mode (0)
|
176
|
-
maintain_direct_map (false)
|
174
|
+
nprobe (1), max_codes (0), parallel_mode (0)
|
177
175
|
{}
|
178
176
|
|
179
177
|
void IndexIVF::add (idx_t n, const float * x)
|
@@ -199,6 +197,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
199
197
|
}
|
200
198
|
|
201
199
|
FAISS_THROW_IF_NOT (is_trained);
|
200
|
+
direct_map.check_can_add (xids);
|
201
|
+
|
202
202
|
std::unique_ptr<idx_t []> idx(new idx_t[n]);
|
203
203
|
quantizer->assign (n, x, idx.get());
|
204
204
|
size_t nadd = 0, nminus1 = 0;
|
@@ -210,6 +210,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
210
210
|
std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
|
211
211
|
encode_vectors (n, x, idx.get(), flat_codes.get());
|
212
212
|
|
213
|
+
DirectMapAdd dm_adder(direct_map, n, xids);
|
214
|
+
|
213
215
|
#pragma omp parallel reduction(+: nadd)
|
214
216
|
{
|
215
217
|
int nt = omp_get_num_threads();
|
@@ -220,13 +222,21 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
220
222
|
idx_t list_no = idx [i];
|
221
223
|
if (list_no >= 0 && list_no % nt == rank) {
|
222
224
|
idx_t id = xids ? xids[i] : ntotal + i;
|
223
|
-
invlists->add_entry (
|
224
|
-
|
225
|
+
size_t ofs = invlists->add_entry (
|
226
|
+
list_no, id,
|
227
|
+
flat_codes.get() + i * code_size
|
228
|
+
);
|
229
|
+
|
230
|
+
dm_adder.add (i, list_no, ofs);
|
231
|
+
|
225
232
|
nadd++;
|
233
|
+
} else if (rank == 0 && list_no == -1) {
|
234
|
+
dm_adder.add (i, -1, 0);
|
226
235
|
}
|
227
236
|
}
|
228
237
|
}
|
229
238
|
|
239
|
+
|
230
240
|
if (verbose) {
|
231
241
|
printf(" added %ld / %ld vectors (%ld -1s)\n", nadd, n, nminus1);
|
232
242
|
}
|
@@ -234,30 +244,18 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
|
234
244
|
ntotal += n;
|
235
245
|
}
|
236
246
|
|
237
|
-
|
238
|
-
void IndexIVF::make_direct_map (bool new_maintain_direct_map)
|
247
|
+
void IndexIVF::make_direct_map (bool b)
|
239
248
|
{
|
240
|
-
|
241
|
-
|
242
|
-
return;
|
243
|
-
|
244
|
-
if (new_maintain_direct_map) {
|
245
|
-
direct_map.resize (ntotal, -1);
|
246
|
-
for (size_t key = 0; key < nlist; key++) {
|
247
|
-
size_t list_size = invlists->list_size (key);
|
248
|
-
ScopedIds idlist (invlists, key);
|
249
|
-
|
250
|
-
for (long ofs = 0; ofs < list_size; ofs++) {
|
251
|
-
FAISS_THROW_IF_NOT_MSG (
|
252
|
-
0 <= idlist [ofs] && idlist[ofs] < ntotal,
|
253
|
-
"direct map supported only for seuquential ids");
|
254
|
-
direct_map [idlist [ofs]] = key << 32 | ofs;
|
255
|
-
}
|
256
|
-
}
|
249
|
+
if (b) {
|
250
|
+
direct_map.set_type (DirectMap::Array, invlists, ntotal);
|
257
251
|
} else {
|
258
|
-
direct_map.
|
252
|
+
direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
|
259
253
|
}
|
260
|
-
|
254
|
+
}
|
255
|
+
|
256
|
+
void IndexIVF::set_direct_map_type (DirectMap::Type type)
|
257
|
+
{
|
258
|
+
direct_map.set_type (type, invlists, ntotal);
|
261
259
|
}
|
262
260
|
|
263
261
|
|
@@ -298,10 +296,13 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
298
296
|
|
299
297
|
bool interrupt = false;
|
300
298
|
|
299
|
+
int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
|
300
|
+
bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
|
301
|
+
|
301
302
|
// don't start parallel section if single query
|
302
303
|
bool do_parallel =
|
303
|
-
|
304
|
-
|
304
|
+
pmode == 0 ? n > 1 :
|
305
|
+
pmode == 1 ? nprobe > 1 :
|
305
306
|
nprobe * n > 1;
|
306
307
|
|
307
308
|
#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
|
@@ -318,6 +319,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
318
319
|
// initialize + reorder a result heap
|
319
320
|
|
320
321
|
auto init_result = [&](float *simi, idx_t *idxi) {
|
322
|
+
if (!do_heap_init) return;
|
321
323
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
322
324
|
heap_heapify<HeapForIP> (k, simi, idxi);
|
323
325
|
} else {
|
@@ -326,6 +328,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
326
328
|
};
|
327
329
|
|
328
330
|
auto reorder_result = [&] (float *simi, idx_t *idxi) {
|
331
|
+
if (!do_heap_init) return;
|
329
332
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
330
333
|
heap_reorder<HeapForIP> (k, simi, idxi);
|
331
334
|
} else {
|
@@ -377,7 +380,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
377
380
|
* Actual loops, depending on parallel_mode
|
378
381
|
****************************************************/
|
379
382
|
|
380
|
-
if (
|
383
|
+
if (pmode == 0) {
|
381
384
|
|
382
385
|
#pragma omp for
|
383
386
|
for (size_t i = 0; i < n; i++) {
|
@@ -417,7 +420,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
417
420
|
}
|
418
421
|
|
419
422
|
} // parallel for
|
420
|
-
} else if (
|
423
|
+
} else if (pmode == 1) {
|
421
424
|
std::vector <idx_t> local_idx (k);
|
422
425
|
std::vector <float> local_dis (k);
|
423
426
|
|
@@ -460,7 +463,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
460
463
|
}
|
461
464
|
} else {
|
462
465
|
FAISS_THROW_FMT ("parallel_mode %d not supported\n",
|
463
|
-
|
466
|
+
pmode);
|
464
467
|
}
|
465
468
|
} // parallel section
|
466
469
|
|
@@ -608,13 +611,8 @@ InvertedListScanner *IndexIVF::get_InvertedListScanner (
|
|
608
611
|
|
609
612
|
void IndexIVF::reconstruct (idx_t key, float* recons) const
|
610
613
|
{
|
611
|
-
|
612
|
-
|
613
|
-
FAISS_THROW_IF_NOT_MSG (key >= 0 && key < direct_map.size(),
|
614
|
-
"invalid key");
|
615
|
-
idx_t list_no = direct_map[key] >> 32;
|
616
|
-
idx_t offset = direct_map[key] & 0xffffffff;
|
617
|
-
reconstruct_from_offset (list_no, offset, recons);
|
614
|
+
idx_t lo = direct_map.get (key);
|
615
|
+
reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
|
618
616
|
}
|
619
617
|
|
620
618
|
|
@@ -682,8 +680,8 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
|
|
682
680
|
// Fill with NaNs
|
683
681
|
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
684
682
|
} else {
|
685
|
-
int list_no = key
|
686
|
-
int offset = key
|
683
|
+
int list_no = lo_listno (key);
|
684
|
+
int offset = lo_offset (key);
|
687
685
|
|
688
686
|
// Update label to the actual id
|
689
687
|
labels[ij] = invlists->get_single_id (list_no, offset);
|
@@ -711,42 +709,41 @@ void IndexIVF::reset ()
|
|
711
709
|
|
712
710
|
size_t IndexIVF::remove_ids (const IDSelector & sel)
|
713
711
|
{
|
714
|
-
|
715
|
-
"direct map remove not implemented");
|
716
|
-
|
717
|
-
std::vector<idx_t> toremove(nlist);
|
718
|
-
|
719
|
-
#pragma omp parallel for
|
720
|
-
for (idx_t i = 0; i < nlist; i++) {
|
721
|
-
idx_t l0 = invlists->list_size (i), l = l0, j = 0;
|
722
|
-
ScopedIds idsi (invlists, i);
|
723
|
-
while (j < l) {
|
724
|
-
if (sel.is_member (idsi[j])) {
|
725
|
-
l--;
|
726
|
-
invlists->update_entry (
|
727
|
-
i, j,
|
728
|
-
invlists->get_single_id (i, l),
|
729
|
-
ScopedCodes (invlists, i, l).get());
|
730
|
-
} else {
|
731
|
-
j++;
|
732
|
-
}
|
733
|
-
}
|
734
|
-
toremove[i] = l0 - l;
|
735
|
-
}
|
736
|
-
// this will not run well in parallel on ondisk because of possible shrinks
|
737
|
-
size_t nremove = 0;
|
738
|
-
for (idx_t i = 0; i < nlist; i++) {
|
739
|
-
if (toremove[i] > 0) {
|
740
|
-
nremove += toremove[i];
|
741
|
-
invlists->resize(
|
742
|
-
i, invlists->list_size(i) - toremove[i]);
|
743
|
-
}
|
744
|
-
}
|
712
|
+
size_t nremove = direct_map.remove_ids (sel, invlists);
|
745
713
|
ntotal -= nremove;
|
746
714
|
return nremove;
|
747
715
|
}
|
748
716
|
|
749
717
|
|
718
|
+
void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
|
719
|
+
{
|
720
|
+
|
721
|
+
if (direct_map.type == DirectMap::Hashtable) {
|
722
|
+
// just remove then add
|
723
|
+
IDSelectorArray sel(n, new_ids);
|
724
|
+
size_t nremove = remove_ids (sel);
|
725
|
+
FAISS_THROW_IF_NOT_MSG (nremove == n,
|
726
|
+
"did not find all entries to remove");
|
727
|
+
add_with_ids (n, x, new_ids);
|
728
|
+
return;
|
729
|
+
}
|
730
|
+
|
731
|
+
FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array);
|
732
|
+
// here it is more tricky because we don't want to introduce holes
|
733
|
+
// in continuous range of ids
|
734
|
+
|
735
|
+
FAISS_THROW_IF_NOT (is_trained);
|
736
|
+
std::vector<idx_t> assign (n);
|
737
|
+
quantizer->assign (n, x, assign.data());
|
738
|
+
|
739
|
+
std::vector<uint8_t> flat_codes (n * code_size);
|
740
|
+
encode_vectors (n, x, assign.data(), flat_codes.data());
|
741
|
+
|
742
|
+
direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data());
|
743
|
+
|
744
|
+
}
|
745
|
+
|
746
|
+
|
750
747
|
|
751
748
|
|
752
749
|
void IndexIVF::train (idx_t n, const float *x)
|
@@ -779,15 +776,14 @@ void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
|
|
779
776
|
FAISS_THROW_IF_NOT (other.code_size == code_size);
|
780
777
|
FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
|
781
778
|
"can only merge indexes of the same type");
|
779
|
+
FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(),
|
780
|
+
"merge direct_map not implemented");
|
782
781
|
}
|
783
782
|
|
784
783
|
|
785
784
|
void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
|
786
785
|
{
|
787
786
|
check_compatible_for_merge (other);
|
788
|
-
FAISS_THROW_IF_NOT_MSG ((!maintain_direct_map &&
|
789
|
-
!other.maintain_direct_map),
|
790
|
-
"direct map copy not implemented");
|
791
787
|
|
792
788
|
invlists->merge_from (other.invlists, add_id);
|
793
789
|
|
@@ -817,7 +813,7 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
|
|
817
813
|
|
818
814
|
FAISS_THROW_IF_NOT (nlist == other.nlist);
|
819
815
|
FAISS_THROW_IF_NOT (code_size == other.code_size);
|
820
|
-
FAISS_THROW_IF_NOT (
|
816
|
+
FAISS_THROW_IF_NOT (other.direct_map.no());
|
821
817
|
FAISS_THROW_IF_NOT_FMT (
|
822
818
|
subset_type == 0 || subset_type == 1 || subset_type == 2,
|
823
819
|
"subset type %d not implemented", subset_type);
|