faiss 0.3.1 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +35 -4
- data/vendor/faiss/faiss/Clustering.h +10 -1
- data/vendor/faiss/faiss/IVFlib.cpp +4 -1
- data/vendor/faiss/faiss/Index.h +21 -6
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
- data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
- data/vendor/faiss/faiss/IndexHNSW.h +52 -3
- data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVF.h +9 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
- data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.h +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
- data/vendor/faiss/faiss/impl/HNSW.h +43 -22
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
- data/vendor/faiss/faiss/impl/io.cpp +13 -5
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
- data/vendor/faiss/faiss/index_factory.cpp +31 -13
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +58 -88
- data/vendor/faiss/faiss/utils/distances.h +5 -5
- data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +10 -3
- data/vendor/faiss/faiss/utils/utils.h +3 -0
- metadata +16 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
#include <faiss/impl/HNSW.h>
|
|
9
9
|
|
|
10
|
+
#include <cstddef>
|
|
10
11
|
#include <string>
|
|
11
12
|
|
|
12
13
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
@@ -110,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
|
|
|
110
111
|
level,
|
|
111
112
|
nb_neighbors(level));
|
|
112
113
|
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
|
|
113
|
-
#pragma omp parallel for reduction(
|
|
114
|
-
|
|
114
|
+
#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
|
|
115
|
+
reduction(+ : tot_reciprocal) reduction(+ : n_node)
|
|
115
116
|
for (int i = 0; i < levels.size(); i++) {
|
|
116
117
|
if (levels[i] > level) {
|
|
117
118
|
n_node++;
|
|
@@ -215,8 +216,8 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
|
|
|
215
216
|
if (pt_level > max_level)
|
|
216
217
|
max_level = pt_level;
|
|
217
218
|
offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
|
|
218
|
-
neighbors.resize(offsets.back(), -1);
|
|
219
219
|
}
|
|
220
|
+
neighbors.resize(offsets.back(), -1);
|
|
220
221
|
|
|
221
222
|
return max_level;
|
|
222
223
|
}
|
|
@@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list(
|
|
|
229
230
|
DistanceComputer& qdis,
|
|
230
231
|
std::priority_queue<NodeDistFarther>& input,
|
|
231
232
|
std::vector<NodeDistFarther>& output,
|
|
232
|
-
int max_size
|
|
233
|
+
int max_size,
|
|
234
|
+
bool keep_max_size_level0) {
|
|
235
|
+
// This prevents number of neighbors at
|
|
236
|
+
// level 0 from being shrunk to less than 2 * M.
|
|
237
|
+
// This is essential in making sure
|
|
238
|
+
// `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
|
|
239
|
+
std::vector<NodeDistFarther> outsiders;
|
|
240
|
+
|
|
233
241
|
while (input.size() > 0) {
|
|
234
242
|
NodeDistFarther v1 = input.top();
|
|
235
243
|
input.pop();
|
|
@@ -250,8 +258,15 @@ void HNSW::shrink_neighbor_list(
|
|
|
250
258
|
if (output.size() >= max_size) {
|
|
251
259
|
return;
|
|
252
260
|
}
|
|
261
|
+
} else if (keep_max_size_level0) {
|
|
262
|
+
outsiders.push_back(v1);
|
|
253
263
|
}
|
|
254
264
|
}
|
|
265
|
+
size_t idx = 0;
|
|
266
|
+
while (keep_max_size_level0 && (output.size() < max_size) &&
|
|
267
|
+
(idx < outsiders.size())) {
|
|
268
|
+
output.push_back(outsiders[idx++]);
|
|
269
|
+
}
|
|
255
270
|
}
|
|
256
271
|
|
|
257
272
|
namespace {
|
|
@@ -268,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
|
|
|
268
283
|
void shrink_neighbor_list(
|
|
269
284
|
DistanceComputer& qdis,
|
|
270
285
|
std::priority_queue<NodeDistCloser>& resultSet1,
|
|
271
|
-
int max_size
|
|
286
|
+
int max_size,
|
|
287
|
+
bool keep_max_size_level0 = false) {
|
|
272
288
|
if (resultSet1.size() < max_size) {
|
|
273
289
|
return;
|
|
274
290
|
}
|
|
@@ -280,7 +296,8 @@ void shrink_neighbor_list(
|
|
|
280
296
|
resultSet1.pop();
|
|
281
297
|
}
|
|
282
298
|
|
|
283
|
-
HNSW::shrink_neighbor_list(
|
|
299
|
+
HNSW::shrink_neighbor_list(
|
|
300
|
+
qdis, resultSet, returnlist, max_size, keep_max_size_level0);
|
|
284
301
|
|
|
285
302
|
for (NodeDistFarther curen2 : returnlist) {
|
|
286
303
|
resultSet1.emplace(curen2.d, curen2.id);
|
|
@@ -294,7 +311,8 @@ void add_link(
|
|
|
294
311
|
DistanceComputer& qdis,
|
|
295
312
|
storage_idx_t src,
|
|
296
313
|
storage_idx_t dest,
|
|
297
|
-
int level
|
|
314
|
+
int level,
|
|
315
|
+
bool keep_max_size_level0 = false) {
|
|
298
316
|
size_t begin, end;
|
|
299
317
|
hnsw.neighbor_range(src, level, &begin, &end);
|
|
300
318
|
if (hnsw.neighbors[end - 1] == -1) {
|
|
@@ -319,7 +337,7 @@ void add_link(
|
|
|
319
337
|
resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
|
|
320
338
|
}
|
|
321
339
|
|
|
322
|
-
shrink_neighbor_list(qdis, resultSet, end - begin);
|
|
340
|
+
shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
|
|
323
341
|
|
|
324
342
|
// ...and back
|
|
325
343
|
size_t i = begin;
|
|
@@ -342,6 +360,9 @@ void search_neighbors_to_add(
|
|
|
342
360
|
float d_entry_point,
|
|
343
361
|
int level,
|
|
344
362
|
VisitedTable& vt) {
|
|
363
|
+
// selects a version
|
|
364
|
+
const bool reference_version = false;
|
|
365
|
+
|
|
345
366
|
// top is nearest candidate
|
|
346
367
|
std::priority_queue<NodeDistFarther> candidates;
|
|
347
368
|
|
|
@@ -363,59 +384,90 @@ void search_neighbors_to_add(
|
|
|
363
384
|
// loop over neighbors
|
|
364
385
|
size_t begin, end;
|
|
365
386
|
hnsw.neighbor_range(currNode, level, &begin, &end);
|
|
366
|
-
for (size_t i = begin; i < end; i++) {
|
|
367
|
-
storage_idx_t nodeId = hnsw.neighbors[i];
|
|
368
|
-
if (nodeId < 0)
|
|
369
|
-
break;
|
|
370
|
-
if (vt.get(nodeId))
|
|
371
|
-
continue;
|
|
372
|
-
vt.set(nodeId);
|
|
373
387
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
388
|
+
// select a version, based on a flag
|
|
389
|
+
if (reference_version) {
|
|
390
|
+
// a reference version
|
|
391
|
+
for (size_t i = begin; i < end; i++) {
|
|
392
|
+
storage_idx_t nodeId = hnsw.neighbors[i];
|
|
393
|
+
if (nodeId < 0)
|
|
394
|
+
break;
|
|
395
|
+
if (vt.get(nodeId))
|
|
396
|
+
continue;
|
|
397
|
+
vt.set(nodeId);
|
|
398
|
+
|
|
399
|
+
float dis = qdis(nodeId);
|
|
400
|
+
NodeDistFarther evE1(dis, nodeId);
|
|
401
|
+
|
|
402
|
+
if (results.size() < hnsw.efConstruction ||
|
|
403
|
+
results.top().d > dis) {
|
|
404
|
+
results.emplace(dis, nodeId);
|
|
405
|
+
candidates.emplace(dis, nodeId);
|
|
406
|
+
if (results.size() > hnsw.efConstruction) {
|
|
407
|
+
results.pop();
|
|
408
|
+
}
|
|
382
409
|
}
|
|
383
410
|
}
|
|
384
|
-
}
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
411
|
+
} else {
|
|
412
|
+
// a faster version
|
|
413
|
+
|
|
414
|
+
// the following version processes 4 neighbors at a time
|
|
415
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
|
416
|
+
const float dis) {
|
|
417
|
+
if (results.size() < hnsw.efConstruction ||
|
|
418
|
+
results.top().d > dis) {
|
|
419
|
+
results.emplace(dis, idx);
|
|
420
|
+
candidates.emplace(dis, idx);
|
|
421
|
+
if (results.size() > hnsw.efConstruction) {
|
|
422
|
+
results.pop();
|
|
423
|
+
}
|
|
424
|
+
}
|
|
425
|
+
};
|
|
388
426
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
**************************************************************/
|
|
427
|
+
int n_buffered = 0;
|
|
428
|
+
storage_idx_t buffered_ids[4];
|
|
392
429
|
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
430
|
+
for (size_t j = begin; j < end; j++) {
|
|
431
|
+
storage_idx_t nodeId = hnsw.neighbors[j];
|
|
432
|
+
if (nodeId < 0)
|
|
433
|
+
break;
|
|
434
|
+
if (vt.get(nodeId)) {
|
|
435
|
+
continue;
|
|
436
|
+
}
|
|
437
|
+
vt.set(nodeId);
|
|
438
|
+
|
|
439
|
+
buffered_ids[n_buffered] = nodeId;
|
|
440
|
+
n_buffered += 1;
|
|
441
|
+
|
|
442
|
+
if (n_buffered == 4) {
|
|
443
|
+
float dis[4];
|
|
444
|
+
qdis.distances_batch_4(
|
|
445
|
+
buffered_ids[0],
|
|
446
|
+
buffered_ids[1],
|
|
447
|
+
buffered_ids[2],
|
|
448
|
+
buffered_ids[3],
|
|
449
|
+
dis[0],
|
|
450
|
+
dis[1],
|
|
451
|
+
dis[2],
|
|
452
|
+
dis[3]);
|
|
453
|
+
|
|
454
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
|
455
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
|
456
|
+
}
|
|
402
457
|
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
nearest = v;
|
|
412
|
-
d_nearest = dis;
|
|
458
|
+
n_buffered = 0;
|
|
459
|
+
}
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
// process leftovers
|
|
463
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
|
464
|
+
float dis = qdis(buffered_ids[icnt]);
|
|
465
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
|
413
466
|
}
|
|
414
|
-
}
|
|
415
|
-
if (nearest == prev_nearest) {
|
|
416
|
-
return;
|
|
417
467
|
}
|
|
418
468
|
}
|
|
469
|
+
|
|
470
|
+
vt.advance();
|
|
419
471
|
}
|
|
420
472
|
|
|
421
473
|
} // namespace
|
|
@@ -429,7 +481,8 @@ void HNSW::add_links_starting_from(
|
|
|
429
481
|
float d_nearest,
|
|
430
482
|
int level,
|
|
431
483
|
omp_lock_t* locks,
|
|
432
|
-
VisitedTable& vt
|
|
484
|
+
VisitedTable& vt,
|
|
485
|
+
bool keep_max_size_level0) {
|
|
433
486
|
std::priority_queue<NodeDistCloser> link_targets;
|
|
434
487
|
|
|
435
488
|
search_neighbors_to_add(
|
|
@@ -438,13 +491,13 @@ void HNSW::add_links_starting_from(
|
|
|
438
491
|
// but we can afford only this many neighbors
|
|
439
492
|
int M = nb_neighbors(level);
|
|
440
493
|
|
|
441
|
-
::faiss::shrink_neighbor_list(ptdis, link_targets, M);
|
|
494
|
+
::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
|
|
442
495
|
|
|
443
496
|
std::vector<storage_idx_t> neighbors;
|
|
444
497
|
neighbors.reserve(link_targets.size());
|
|
445
498
|
while (!link_targets.empty()) {
|
|
446
499
|
storage_idx_t other_id = link_targets.top().id;
|
|
447
|
-
add_link(*this, ptdis, pt_id, other_id, level);
|
|
500
|
+
add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
|
|
448
501
|
neighbors.push_back(other_id);
|
|
449
502
|
link_targets.pop();
|
|
450
503
|
}
|
|
@@ -452,7 +505,7 @@ void HNSW::add_links_starting_from(
|
|
|
452
505
|
omp_unset_lock(&locks[pt_id]);
|
|
453
506
|
for (storage_idx_t other_id : neighbors) {
|
|
454
507
|
omp_set_lock(&locks[other_id]);
|
|
455
|
-
add_link(*this, ptdis, other_id, pt_id, level);
|
|
508
|
+
add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
|
|
456
509
|
omp_unset_lock(&locks[other_id]);
|
|
457
510
|
}
|
|
458
511
|
omp_set_lock(&locks[pt_id]);
|
|
@@ -467,7 +520,8 @@ void HNSW::add_with_locks(
|
|
|
467
520
|
int pt_level,
|
|
468
521
|
int pt_id,
|
|
469
522
|
std::vector<omp_lock_t>& locks,
|
|
470
|
-
VisitedTable& vt
|
|
523
|
+
VisitedTable& vt,
|
|
524
|
+
bool keep_max_size_level0) {
|
|
471
525
|
// greedy search on upper levels
|
|
472
526
|
|
|
473
527
|
storage_idx_t nearest;
|
|
@@ -496,7 +550,14 @@ void HNSW::add_with_locks(
|
|
|
496
550
|
|
|
497
551
|
for (; level >= 0; level--) {
|
|
498
552
|
add_links_starting_from(
|
|
499
|
-
ptdis,
|
|
553
|
+
ptdis,
|
|
554
|
+
pt_id,
|
|
555
|
+
nearest,
|
|
556
|
+
d_nearest,
|
|
557
|
+
level,
|
|
558
|
+
locks.data(),
|
|
559
|
+
vt,
|
|
560
|
+
keep_max_size_level0);
|
|
500
561
|
}
|
|
501
562
|
|
|
502
563
|
omp_unset_lock(&locks[pt_id]);
|
|
@@ -511,12 +572,10 @@ void HNSW::add_with_locks(
|
|
|
511
572
|
* Searching
|
|
512
573
|
**************************************************************/
|
|
513
574
|
|
|
514
|
-
namespace {
|
|
515
575
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
|
516
576
|
using Node = HNSW::Node;
|
|
517
577
|
using C = HNSW::C;
|
|
518
578
|
/** Do a BFS on the candidates list */
|
|
519
|
-
|
|
520
579
|
int search_from_candidates(
|
|
521
580
|
const HNSW& hnsw,
|
|
522
581
|
DistanceComputer& qdis,
|
|
@@ -525,8 +584,8 @@ int search_from_candidates(
|
|
|
525
584
|
VisitedTable& vt,
|
|
526
585
|
HNSWStats& stats,
|
|
527
586
|
int level,
|
|
528
|
-
int nres_in
|
|
529
|
-
const SearchParametersHNSW* params
|
|
587
|
+
int nres_in,
|
|
588
|
+
const SearchParametersHNSW* params) {
|
|
530
589
|
int nres = nres_in;
|
|
531
590
|
int ndis = 0;
|
|
532
591
|
|
|
@@ -571,27 +630,7 @@ int search_from_candidates(
|
|
|
571
630
|
size_t begin, end;
|
|
572
631
|
hnsw.neighbor_range(v0, level, &begin, &end);
|
|
573
632
|
|
|
574
|
-
//
|
|
575
|
-
// for (size_t j = begin; j < end; j++) {
|
|
576
|
-
// int v1 = hnsw.neighbors[j];
|
|
577
|
-
// if (v1 < 0)
|
|
578
|
-
// break;
|
|
579
|
-
// if (vt.get(v1)) {
|
|
580
|
-
// continue;
|
|
581
|
-
// }
|
|
582
|
-
// vt.set(v1);
|
|
583
|
-
// ndis++;
|
|
584
|
-
// float d = qdis(v1);
|
|
585
|
-
// if (!sel || sel->is_member(v1)) {
|
|
586
|
-
// if (nres < k) {
|
|
587
|
-
// faiss::maxheap_push(++nres, D, I, d, v1);
|
|
588
|
-
// } else if (d < D[0]) {
|
|
589
|
-
// faiss::maxheap_replace_top(nres, D, I, d, v1);
|
|
590
|
-
// }
|
|
591
|
-
// }
|
|
592
|
-
// candidates.push(v1, d);
|
|
593
|
-
// }
|
|
594
|
-
|
|
633
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
|
595
634
|
// the following version processes 4 neighbors at a time
|
|
596
635
|
size_t jmax = begin;
|
|
597
636
|
for (size_t j = begin; j < end; j++) {
|
|
@@ -606,7 +645,6 @@ int search_from_candidates(
|
|
|
606
645
|
int counter = 0;
|
|
607
646
|
size_t saved_j[4];
|
|
608
647
|
|
|
609
|
-
ndis += jmax - begin;
|
|
610
648
|
threshold = res.threshold;
|
|
611
649
|
|
|
612
650
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
|
@@ -614,6 +652,7 @@ int search_from_candidates(
|
|
|
614
652
|
if (dis < threshold) {
|
|
615
653
|
if (res.add_result(dis, idx)) {
|
|
616
654
|
threshold = res.threshold;
|
|
655
|
+
nres += 1;
|
|
617
656
|
}
|
|
618
657
|
}
|
|
619
658
|
}
|
|
@@ -644,6 +683,8 @@ int search_from_candidates(
|
|
|
644
683
|
add_to_heap(saved_j[id4], dis[id4]);
|
|
645
684
|
}
|
|
646
685
|
|
|
686
|
+
ndis += 4;
|
|
687
|
+
|
|
647
688
|
counter = 0;
|
|
648
689
|
}
|
|
649
690
|
}
|
|
@@ -651,6 +692,8 @@ int search_from_candidates(
|
|
|
651
692
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
|
652
693
|
float dis = qdis(saved_j[icnt]);
|
|
653
694
|
add_to_heap(saved_j[icnt], dis);
|
|
695
|
+
|
|
696
|
+
ndis += 1;
|
|
654
697
|
}
|
|
655
698
|
|
|
656
699
|
nstep++;
|
|
@@ -664,7 +707,8 @@ int search_from_candidates(
|
|
|
664
707
|
if (candidates.size() == 0) {
|
|
665
708
|
stats.n2++;
|
|
666
709
|
}
|
|
667
|
-
stats.
|
|
710
|
+
stats.ndis += ndis;
|
|
711
|
+
stats.nhops += nstep;
|
|
668
712
|
}
|
|
669
713
|
|
|
670
714
|
return nres;
|
|
@@ -700,33 +744,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
|
700
744
|
size_t begin, end;
|
|
701
745
|
hnsw.neighbor_range(v0, 0, &begin, &end);
|
|
702
746
|
|
|
703
|
-
//
|
|
704
|
-
// for (size_t j = begin; j < end; ++j) {
|
|
705
|
-
// int v1 = hnsw.neighbors[j];
|
|
706
|
-
//
|
|
707
|
-
// if (v1 < 0) {
|
|
708
|
-
// break;
|
|
709
|
-
// }
|
|
710
|
-
// if (vt->get(v1)) {
|
|
711
|
-
// continue;
|
|
712
|
-
// }
|
|
713
|
-
//
|
|
714
|
-
// vt->set(v1);
|
|
715
|
-
//
|
|
716
|
-
// float d1 = qdis(v1);
|
|
717
|
-
// ++ndis;
|
|
718
|
-
//
|
|
719
|
-
// if (top_candidates.top().first > d1 ||
|
|
720
|
-
// top_candidates.size() < ef) {
|
|
721
|
-
// candidates.emplace(d1, v1);
|
|
722
|
-
// top_candidates.emplace(d1, v1);
|
|
723
|
-
//
|
|
724
|
-
// if (top_candidates.size() > ef) {
|
|
725
|
-
// top_candidates.pop();
|
|
726
|
-
// }
|
|
727
|
-
// }
|
|
728
|
-
// }
|
|
729
|
-
|
|
747
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
|
730
748
|
// the following version processes 4 neighbors at a time
|
|
731
749
|
size_t jmax = begin;
|
|
732
750
|
for (size_t j = begin; j < end; j++) {
|
|
@@ -741,8 +759,6 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
|
741
759
|
int counter = 0;
|
|
742
760
|
size_t saved_j[4];
|
|
743
761
|
|
|
744
|
-
ndis += jmax - begin;
|
|
745
|
-
|
|
746
762
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
|
747
763
|
if (top_candidates.top().first > dis ||
|
|
748
764
|
top_candidates.size() < ef) {
|
|
@@ -779,6 +795,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
|
779
795
|
add_to_heap(saved_j[id4], dis[id4]);
|
|
780
796
|
}
|
|
781
797
|
|
|
798
|
+
ndis += 4;
|
|
799
|
+
|
|
782
800
|
counter = 0;
|
|
783
801
|
}
|
|
784
802
|
}
|
|
@@ -786,18 +804,102 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
|
786
804
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
|
787
805
|
float dis = qdis(saved_j[icnt]);
|
|
788
806
|
add_to_heap(saved_j[icnt], dis);
|
|
807
|
+
|
|
808
|
+
ndis += 1;
|
|
789
809
|
}
|
|
810
|
+
|
|
811
|
+
stats.nhops += 1;
|
|
790
812
|
}
|
|
791
813
|
|
|
792
814
|
++stats.n1;
|
|
793
815
|
if (candidates.size() == 0) {
|
|
794
816
|
++stats.n2;
|
|
795
817
|
}
|
|
796
|
-
stats.
|
|
818
|
+
stats.ndis += ndis;
|
|
797
819
|
|
|
798
820
|
return top_candidates;
|
|
799
821
|
}
|
|
800
822
|
|
|
823
|
+
/// greedily update a nearest vector at a given level
|
|
824
|
+
HNSWStats greedy_update_nearest(
|
|
825
|
+
const HNSW& hnsw,
|
|
826
|
+
DistanceComputer& qdis,
|
|
827
|
+
int level,
|
|
828
|
+
storage_idx_t& nearest,
|
|
829
|
+
float& d_nearest) {
|
|
830
|
+
HNSWStats stats;
|
|
831
|
+
|
|
832
|
+
for (;;) {
|
|
833
|
+
storage_idx_t prev_nearest = nearest;
|
|
834
|
+
|
|
835
|
+
size_t begin, end;
|
|
836
|
+
hnsw.neighbor_range(nearest, level, &begin, &end);
|
|
837
|
+
|
|
838
|
+
size_t ndis = 0;
|
|
839
|
+
|
|
840
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
|
841
|
+
// the following version processes 4 neighbors at a time
|
|
842
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
|
843
|
+
const float dis) {
|
|
844
|
+
if (dis < d_nearest) {
|
|
845
|
+
nearest = idx;
|
|
846
|
+
d_nearest = dis;
|
|
847
|
+
}
|
|
848
|
+
};
|
|
849
|
+
|
|
850
|
+
int n_buffered = 0;
|
|
851
|
+
storage_idx_t buffered_ids[4];
|
|
852
|
+
|
|
853
|
+
for (size_t j = begin; j < end; j++) {
|
|
854
|
+
storage_idx_t v = hnsw.neighbors[j];
|
|
855
|
+
if (v < 0)
|
|
856
|
+
break;
|
|
857
|
+
ndis += 1;
|
|
858
|
+
|
|
859
|
+
buffered_ids[n_buffered] = v;
|
|
860
|
+
n_buffered += 1;
|
|
861
|
+
|
|
862
|
+
if (n_buffered == 4) {
|
|
863
|
+
float dis[4];
|
|
864
|
+
qdis.distances_batch_4(
|
|
865
|
+
buffered_ids[0],
|
|
866
|
+
buffered_ids[1],
|
|
867
|
+
buffered_ids[2],
|
|
868
|
+
buffered_ids[3],
|
|
869
|
+
dis[0],
|
|
870
|
+
dis[1],
|
|
871
|
+
dis[2],
|
|
872
|
+
dis[3]);
|
|
873
|
+
|
|
874
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
|
875
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
|
876
|
+
}
|
|
877
|
+
|
|
878
|
+
n_buffered = 0;
|
|
879
|
+
}
|
|
880
|
+
}
|
|
881
|
+
|
|
882
|
+
// process leftovers
|
|
883
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
|
884
|
+
float dis = qdis(buffered_ids[icnt]);
|
|
885
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
|
886
|
+
}
|
|
887
|
+
|
|
888
|
+
// update stats
|
|
889
|
+
stats.ndis += ndis;
|
|
890
|
+
stats.nhops += 1;
|
|
891
|
+
|
|
892
|
+
if (nearest == prev_nearest) {
|
|
893
|
+
return stats;
|
|
894
|
+
}
|
|
895
|
+
}
|
|
896
|
+
}
|
|
897
|
+
|
|
898
|
+
namespace {
|
|
899
|
+
using MinimaxHeap = HNSW::MinimaxHeap;
|
|
900
|
+
using Node = HNSW::Node;
|
|
901
|
+
using C = HNSW::C;
|
|
902
|
+
|
|
801
903
|
// just used as a lower bound for the minmaxheap, but it is set for heap search
|
|
802
904
|
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
|
803
905
|
using RH = HeapBlockResultHandler<C>;
|
|
@@ -807,7 +909,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
|
|
807
909
|
return 1;
|
|
808
910
|
}
|
|
809
911
|
|
|
810
|
-
} //
|
|
912
|
+
} // namespace
|
|
811
913
|
|
|
812
914
|
HNSWStats HNSW::search(
|
|
813
915
|
DistanceComputer& qdis,
|
|
@@ -820,85 +922,47 @@ HNSWStats HNSW::search(
|
|
|
820
922
|
}
|
|
821
923
|
int k = extract_k_from_ResultHandler(res);
|
|
822
924
|
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
storage_idx_t nearest = entry_point;
|
|
826
|
-
float d_nearest = qdis(nearest);
|
|
827
|
-
|
|
828
|
-
for (int level = max_level; level >= 1; level--) {
|
|
829
|
-
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
|
830
|
-
}
|
|
925
|
+
bool bounded_queue =
|
|
926
|
+
params ? params->bounded_queue : this->search_bounded_queue;
|
|
831
927
|
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
candidates.push(nearest, d_nearest);
|
|
928
|
+
// greedy search on upper levels
|
|
929
|
+
storage_idx_t nearest = entry_point;
|
|
930
|
+
float d_nearest = qdis(nearest);
|
|
837
931
|
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
*this,
|
|
844
|
-
Node(d_nearest, nearest),
|
|
845
|
-
qdis,
|
|
846
|
-
ef,
|
|
847
|
-
&vt,
|
|
848
|
-
stats);
|
|
849
|
-
|
|
850
|
-
while (top_candidates.size() > k) {
|
|
851
|
-
top_candidates.pop();
|
|
852
|
-
}
|
|
932
|
+
for (int level = max_level; level >= 1; level--) {
|
|
933
|
+
HNSWStats local_stats =
|
|
934
|
+
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
|
935
|
+
stats.combine(local_stats);
|
|
936
|
+
}
|
|
853
937
|
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
std::tie(d, label) = top_candidates.top();
|
|
858
|
-
res.add_result(d, label);
|
|
859
|
-
top_candidates.pop();
|
|
860
|
-
}
|
|
861
|
-
}
|
|
938
|
+
int ef = std::max(params ? params->efSearch : efSearch, k);
|
|
939
|
+
if (bounded_queue) { // this is the most common branch
|
|
940
|
+
MinimaxHeap candidates(ef);
|
|
862
941
|
|
|
863
|
-
|
|
942
|
+
candidates.push(nearest, d_nearest);
|
|
864
943
|
|
|
944
|
+
search_from_candidates(
|
|
945
|
+
*this, qdis, res, candidates, vt, stats, 0, 0, params);
|
|
865
946
|
} else {
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
std::vector<idx_t> I_to_next(candidates_size);
|
|
870
|
-
std::vector<float> D_to_next(candidates_size);
|
|
871
|
-
|
|
872
|
-
HeapBlockResultHandler<C> block_resh(
|
|
873
|
-
1, D_to_next.data(), I_to_next.data(), candidates_size);
|
|
874
|
-
HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
|
|
875
|
-
|
|
876
|
-
int nres = 1;
|
|
877
|
-
I_to_next[0] = entry_point;
|
|
878
|
-
D_to_next[0] = qdis(entry_point);
|
|
879
|
-
|
|
880
|
-
for (int level = max_level; level >= 0; level--) {
|
|
881
|
-
// copy I, D -> candidates
|
|
882
|
-
|
|
883
|
-
candidates.clear();
|
|
947
|
+
std::priority_queue<Node> top_candidates =
|
|
948
|
+
search_from_candidate_unbounded(
|
|
949
|
+
*this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
|
|
884
950
|
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
951
|
+
while (top_candidates.size() > k) {
|
|
952
|
+
top_candidates.pop();
|
|
953
|
+
}
|
|
888
954
|
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
*this, qdis, resh, candidates, vt, stats, level);
|
|
896
|
-
resh.end();
|
|
897
|
-
}
|
|
898
|
-
vt.advance();
|
|
955
|
+
while (!top_candidates.empty()) {
|
|
956
|
+
float d;
|
|
957
|
+
storage_idx_t label;
|
|
958
|
+
std::tie(d, label) = top_candidates.top();
|
|
959
|
+
res.add_result(d, label);
|
|
960
|
+
top_candidates.pop();
|
|
899
961
|
}
|
|
900
962
|
}
|
|
901
963
|
|
|
964
|
+
vt.advance();
|
|
965
|
+
|
|
902
966
|
return stats;
|
|
903
967
|
}
|
|
904
968
|
|
|
@@ -910,9 +974,12 @@ void HNSW::search_level_0(
|
|
|
910
974
|
const float* nearest_d,
|
|
911
975
|
int search_type,
|
|
912
976
|
HNSWStats& search_stats,
|
|
913
|
-
VisitedTable& vt
|
|
977
|
+
VisitedTable& vt,
|
|
978
|
+
const SearchParametersHNSW* params) const {
|
|
914
979
|
const HNSW& hnsw = *this;
|
|
980
|
+
auto efSearch = params ? params->efSearch : hnsw.efSearch;
|
|
915
981
|
int k = extract_k_from_ResultHandler(res);
|
|
982
|
+
|
|
916
983
|
if (search_type == 1) {
|
|
917
984
|
int nres = 0;
|
|
918
985
|
|
|
@@ -925,16 +992,25 @@ void HNSW::search_level_0(
|
|
|
925
992
|
if (vt.get(cj))
|
|
926
993
|
continue;
|
|
927
994
|
|
|
928
|
-
int candidates_size = std::max(
|
|
995
|
+
int candidates_size = std::max(efSearch, k);
|
|
929
996
|
MinimaxHeap candidates(candidates_size);
|
|
930
997
|
|
|
931
998
|
candidates.push(cj, nearest_d[j]);
|
|
932
999
|
|
|
933
1000
|
nres = search_from_candidates(
|
|
934
|
-
hnsw,
|
|
1001
|
+
hnsw,
|
|
1002
|
+
qdis,
|
|
1003
|
+
res,
|
|
1004
|
+
candidates,
|
|
1005
|
+
vt,
|
|
1006
|
+
search_stats,
|
|
1007
|
+
0,
|
|
1008
|
+
nres,
|
|
1009
|
+
params);
|
|
1010
|
+
nres = std::min(nres, candidates_size);
|
|
935
1011
|
}
|
|
936
1012
|
} else if (search_type == 2) {
|
|
937
|
-
int candidates_size = std::max(
|
|
1013
|
+
int candidates_size = std::max(efSearch, int(k));
|
|
938
1014
|
candidates_size = std::max(candidates_size, int(nprobe));
|
|
939
1015
|
|
|
940
1016
|
MinimaxHeap candidates(candidates_size);
|
|
@@ -947,7 +1023,7 @@ void HNSW::search_level_0(
|
|
|
947
1023
|
}
|
|
948
1024
|
|
|
949
1025
|
search_from_candidates(
|
|
950
|
-
hnsw, qdis, res, candidates, vt, search_stats, 0);
|
|
1026
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
|
|
951
1027
|
}
|
|
952
1028
|
}
|
|
953
1029
|
|
|
@@ -1013,7 +1089,99 @@ void HNSW::MinimaxHeap::clear() {
|
|
|
1013
1089
|
nvalid = k = 0;
|
|
1014
1090
|
}
|
|
1015
1091
|
|
|
1016
|
-
#ifdef
|
|
1092
|
+
#ifdef __AVX512F__
|
|
1093
|
+
|
|
1094
|
+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
1095
|
+
assert(k > 0);
|
|
1096
|
+
static_assert(
|
|
1097
|
+
std::is_same<storage_idx_t, int32_t>::value,
|
|
1098
|
+
"This code expects storage_idx_t to be int32_t");
|
|
1099
|
+
|
|
1100
|
+
int32_t min_idx = -1;
|
|
1101
|
+
float min_dis = std::numeric_limits<float>::infinity();
|
|
1102
|
+
|
|
1103
|
+
__m512i min_indices = _mm512_set1_epi32(-1);
|
|
1104
|
+
__m512 min_distances =
|
|
1105
|
+
_mm512_set1_ps(std::numeric_limits<float>::infinity());
|
|
1106
|
+
__m512i current_indices = _mm512_setr_epi32(
|
|
1107
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
|
1108
|
+
__m512i offset = _mm512_set1_epi32(16);
|
|
1109
|
+
|
|
1110
|
+
// The following loop tracks the rightmost index with the min distance.
|
|
1111
|
+
// -1 index values are ignored.
|
|
1112
|
+
const int k16 = (k / 16) * 16;
|
|
1113
|
+
for (size_t iii = 0; iii < k16; iii += 16) {
|
|
1114
|
+
__m512i indices =
|
|
1115
|
+
_mm512_loadu_si512((const __m512i*)(ids.data() + iii));
|
|
1116
|
+
__m512 distances = _mm512_loadu_ps(dis.data() + iii);
|
|
1117
|
+
|
|
1118
|
+
// This mask filters out -1 values among indices.
|
|
1119
|
+
__mmask16 m1mask =
|
|
1120
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
|
1121
|
+
|
|
1122
|
+
__mmask16 dmask =
|
|
1123
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
|
1124
|
+
__mmask16 finalmask = m1mask | dmask;
|
|
1125
|
+
|
|
1126
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
|
1127
|
+
finalmask, current_indices, min_indices);
|
|
1128
|
+
const __m512 min_distances_new =
|
|
1129
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
|
1130
|
+
|
|
1131
|
+
min_indices = min_indices_new;
|
|
1132
|
+
min_distances = min_distances_new;
|
|
1133
|
+
|
|
1134
|
+
current_indices = _mm512_add_epi32(current_indices, offset);
|
|
1135
|
+
}
|
|
1136
|
+
|
|
1137
|
+
// leftovers
|
|
1138
|
+
if (k16 != k) {
|
|
1139
|
+
const __mmask16 kmask = (1 << (k - k16)) - 1;
|
|
1140
|
+
|
|
1141
|
+
__m512i indices = _mm512_mask_loadu_epi32(
|
|
1142
|
+
_mm512_set1_epi32(-1), kmask, ids.data() + k16);
|
|
1143
|
+
__m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
|
|
1144
|
+
|
|
1145
|
+
// This mask filters out -1 values among indices.
|
|
1146
|
+
__mmask16 m1mask =
|
|
1147
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
|
1148
|
+
|
|
1149
|
+
__mmask16 dmask =
|
|
1150
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
|
1151
|
+
__mmask16 finalmask = m1mask | dmask;
|
|
1152
|
+
|
|
1153
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
|
1154
|
+
finalmask, current_indices, min_indices);
|
|
1155
|
+
const __m512 min_distances_new =
|
|
1156
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
|
1157
|
+
|
|
1158
|
+
min_indices = min_indices_new;
|
|
1159
|
+
min_distances = min_distances_new;
|
|
1160
|
+
}
|
|
1161
|
+
|
|
1162
|
+
// grab min distance
|
|
1163
|
+
min_dis = _mm512_reduce_min_ps(min_distances);
|
|
1164
|
+
// blend
|
|
1165
|
+
__mmask16 mindmask =
|
|
1166
|
+
_mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
|
|
1167
|
+
// pick the max one
|
|
1168
|
+
min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
|
|
1169
|
+
|
|
1170
|
+
if (min_idx == -1) {
|
|
1171
|
+
return -1;
|
|
1172
|
+
}
|
|
1173
|
+
|
|
1174
|
+
if (vmin_out) {
|
|
1175
|
+
*vmin_out = min_dis;
|
|
1176
|
+
}
|
|
1177
|
+
int ret = ids[min_idx];
|
|
1178
|
+
ids[min_idx] = -1;
|
|
1179
|
+
--nvalid;
|
|
1180
|
+
return ret;
|
|
1181
|
+
}
|
|
1182
|
+
|
|
1183
|
+
#elif __AVX2__
|
|
1184
|
+
|
|
1017
1185
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
1018
1186
|
assert(k > 0);
|
|
1019
1187
|
static_assert(
|