faiss 0.2.6 → 0.2.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +2 -2
- data/vendor/faiss/faiss/AutoTune.cpp +15 -4
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +1 -5
- data/vendor/faiss/faiss/Clustering.h +0 -2
- data/vendor/faiss/faiss/IVFlib.h +0 -2
- data/vendor/faiss/faiss/Index.h +1 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
- data/vendor/faiss/faiss/IndexBinary.h +0 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
- data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
- data/vendor/faiss/faiss/IndexFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
- data/vendor/faiss/faiss/IndexHNSW.h +0 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
- data/vendor/faiss/faiss/IndexIDMap.h +0 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
- data/vendor/faiss/faiss/IndexIVF.h +121 -61
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
- data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
- data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
- data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
- data/vendor/faiss/faiss/IndexReplicas.h +0 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
- data/vendor/faiss/faiss/IndexShards.cpp +26 -109
- data/vendor/faiss/faiss/IndexShards.h +2 -3
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
- data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
- data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
- data/vendor/faiss/faiss/MetaIndexes.h +29 -0
- data/vendor/faiss/faiss/MetricType.h +14 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
- data/vendor/faiss/faiss/VectorTransform.h +1 -3
- data/vendor/faiss/faiss/clone_index.cpp +232 -18
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
- data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
- data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
- data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
- data/vendor/faiss/faiss/impl/HNSW.h +6 -9
- data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
- data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
- data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
- data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
- data/vendor/faiss/faiss/impl/NSG.h +4 -7
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
- data/vendor/faiss/faiss/index_factory.cpp +8 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
- data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
- data/vendor/faiss/faiss/utils/Heap.h +35 -1
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
- data/vendor/faiss/faiss/utils/distances.cpp +61 -7
- data/vendor/faiss/faiss/utils/distances.h +11 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
- data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
- data/vendor/faiss/faiss/utils/fp16.h +7 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
- data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
- data/vendor/faiss/faiss/utils/hamming.h +21 -10
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
- data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
- data/vendor/faiss/faiss/utils/sorting.h +71 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
- data/vendor/faiss/faiss/utils/utils.cpp +4 -176
- data/vendor/faiss/faiss/utils/utils.h +2 -9
- metadata +29 -3
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// This file contains an implementation of approximate top-k search
// using heap. It was initially created for a beam search.
//
// The core idea is the following.
// Say we need to find beam_size indices with the minimal distance
// values. It is done via heap (priority_queue) using the following
// pseudocode:
//
// def baseline():
//   distances = np.empty([beam_size * n], dtype=float)
//   indices = np.empty([beam_size * n], dtype=int)
//
//   heap = Heap(max_heap_size=beam_size)
//
//   for i in range(0, beam_size * n):
//     heap.push(distances[i], indices[i])
//
// Basically, this is what heap_addn() function from utils/Heap.h does.
//
// The following scheme can be used for approximate beam search.
// Say, we need to find elements with min distance.
// Basically, we split n elements of every beam into NBUCKETS buckets
// and track the index with the minimal distance for every bucket.
// This can be effectively SIMD-ed and significantly lowers the number
// of operations, but yields approximate results for beam_size >= 2.
//
// def approximate_v1():
//   distances = np.empty([beam_size * n], dtype=float)
//   indices = np.empty([beam_size * n], dtype=int)
//
//   heap = Heap(max_heap_size=beam_size)
//
//   for beam in range(0, beam_size):
//     # The value of 32 is just an example.
//     # The value may be varied: the larger the value is,
//     # the slower and the more precise vs baseline beam search is
//     NBUCKETS = 32
//
//     local_min_distances = [HUGE_VALF] * NBUCKETS
//     local_min_indices = [0] * NBUCKETS
//
//     for i in range(0, n / NBUCKETS):
//       for j in range(0, NBUCKETS):
//         idx = beam * n + i * NBUCKETS + j
//         if distances[idx] < local_min_distances[j]:
//           local_min_distances[i] = distances[idx]
//           local_min_indices[i] = indices[idx]
//
//     for j in range(0, NBUCKETS):
//       heap.push(local_min_distances[j], local_min_indices[j])
//
// The accuracy can be improved by tracking min-2 elements for every
// bucket. Such a min-2 implementation with NBUCKETS buckets provides
// better accuracy than top-1 implementation with 2 * NBUCKETS buckets.
// Min-3 is also doable. One can use min-N approach, but I'm not sure
// whether min-4 and above are practical, because of the lack of SIMD
// registers (unless AVX-512 version is used).
//
// C++ template for top-N implementation is provided. The code
// assumes that indices[idx] == idx. One can write a code that lifts
// such an assumption easily.
//
// Currently, the code that tracks elements with min distances is implemented
// (Max Heap). Min Heap option can be added easily.

#pragma once

#include <faiss/impl/platform_macros.h>

// the list of available modes is in the following file
#include <faiss/utils/approx_topk/mode.h>

// Dispatch header: picks the SIMD implementation when AVX2 is available,
// otherwise falls back to the scalar (compiler-autovectorizable) variant.
// Both headers define the same HeapWithBuckets<C, NBUCKETS, N> interface.
#ifdef __AVX2__
#include <faiss/utils/approx_topk/avx2-inl.h>
#else
#include <faiss/utils/approx_topk/generic.h>
#endif
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <immintrin.h>

#include <limits>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>

namespace faiss {

// Approximate top-k tracker: NBUCKETS independent buckets, each keeping
// its min-N candidates, merged into a regular heap at the end.
// Only the CMax<float, int> (max-heap, tracks minima) case is provided.
template <typename C, uint32_t NBUCKETS, uint32_t N>
struct HeapWithBuckets {
    // this case was not implemented yet.
};

// AVX2 specialization: 8 buckets are processed per __m256 register,
// so NBUCKETS must be a positive multiple of 8.
template <uint32_t NBUCKETS, uint32_t N>
struct HeapWithBuckets<CMax<float, int>, NBUCKETS, N> {
    // number of 8-lane SIMD registers needed to cover all buckets
    static constexpr uint32_t NBUCKETS_8 = NBUCKETS / 8;
    static_assert(
            (NBUCKETS) > 0 && ((NBUCKETS % 8) == 0),
            "Number of buckets needs to be 8, 16, 24, ...");

    // Add n elements into the heap (bh_val / bh_ids), approximately.
    // NOTE(review): the heap is assumed to be already initialized by the
    // caller (bh_val[0] is read before any push) — confirm against callers.
    static void addn(
            // number of elements
            const uint32_t n,
            // distances. It is assumed to have n elements.
            const float* const __restrict distances,
            // number of best elements to keep
            const uint32_t k,
            // output distances
            float* const __restrict bh_val,
            // output indices, each being within [0, n) range
            int32_t* const __restrict bh_ids) {
        // forward a call to bs_addn with 1 beam
        bs_addn(1, n, distances, k, bh_val, bh_ids);
    }

    // Beam-search flavor: processes beam_size contiguous slices of
    // n_per_beam distances each, all feeding the same output heap.
    static void bs_addn(
            // beam_size parameter of Beam Search algorithm
            const uint32_t beam_size,
            // number of elements per beam
            const uint32_t n_per_beam,
            // distances. It is assumed to have (n_per_beam * beam_size)
            // elements.
            const float* const __restrict distances,
            // number of best elements to keep
            const uint32_t k,
            // output distances
            float* const __restrict bh_val,
            // output indices, each being within [0, n_per_beam * beam_size)
            // range
            int32_t* const __restrict bh_ids) {
        // // Basically, the function runs beam_size iterations.
        // // Every iteration NBUCKETS * N elements are added to a regular heap.
        // // So, maximum number of added elements is beam_size * NBUCKETS * N.
        // // This number is expected to be less or equal than k.
        // FAISS_THROW_IF_NOT_FMT(
        //      beam_size * NBUCKETS * N >= k,
        //      "Cannot pick %d elements, only %d. "
        //      "Check the function and template arguments values.",
        //      k,
        //      beam_size * NBUCKETS * N);

        using C = CMax<float, int>;

        // main loop
        for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
            // per-bucket min-N state; for a fixed bucket lane, the values
            // at p = 0..N-1 are kept in ascending order (insertion below
            // preserves this invariant)
            __m256 min_distances_i[NBUCKETS_8][N];
            __m256i min_indices_i[NBUCKETS_8][N];

            for (uint32_t j = 0; j < NBUCKETS_8; j++) {
                for (uint32_t p = 0; p < N; p++) {
                    min_distances_i[j][p] =
                            _mm256_set1_ps(std::numeric_limits<float>::max());
                    // indices start as lane offsets 0..7; the per-register
                    // base offset (j * 8 etc.) is added in the fixup pass
                    min_indices_i[j][p] =
                            _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
                }
            }

            __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
            __m256i indices_delta = _mm256_set1_epi32(NBUCKETS);

            // number of elements that fill whole NBUCKETS-sized chunks;
            // the remainder is handled in the leftovers loop
            const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;

            // put the data into buckets
            for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
                for (uint32_t j = 0; j < NBUCKETS_8; j++) {
                    const __m256 distances_reg = _mm256_loadu_ps(
                            distances + j * 8 + ip + n_per_beam * beam_index);

                    // insertion cascade into the sorted min-N list: the
                    // candidate starts as the new element and becomes the
                    // displaced element at each level.
                    // loop. Compiler should get rid of unneeded ops
                    __m256 distance_candidate = distances_reg;
                    __m256i indices_candidate = current_indices;

                    for (uint32_t p = 0; p < N; p++) {
                        // lanes where the stored min is <= candidate keep
                        // their value (ordered compare, as in `a <= b`)
                        const __m256 comparison = _mm256_cmp_ps(
                                min_distances_i[j][p],
                                distance_candidate,
                                _CMP_LE_OS);

                        // // blend seems to be slower that min
                        // const __m256 min_distances_new = _mm256_blendv_ps(
                        //      distance_candidate,
                        //      min_distances_i[j][p],
                        //      comparison);
                        const __m256 min_distances_new = _mm256_min_ps(
                                distance_candidate, min_distances_i[j][p]);
                        const __m256i min_indices_new =
                                _mm256_castps_si256(_mm256_blendv_ps(
                                        _mm256_castsi256_ps(indices_candidate),
                                        _mm256_castsi256_ps(
                                                min_indices_i[j][p]),
                                        comparison));

                        // // blend seems to be slower that min
                        // const __m256 max_distances_new = _mm256_blendv_ps(
                        //      min_distances_i[j][p],
                        //      distance_candidate,
                        //      comparison);
                        // NOTE(review): using distances_reg (not
                        // distance_candidate) here is correct only because
                        // the per-level values are kept sorted — the two
                        // produce the same max under that invariant.
                        const __m256 max_distances_new = _mm256_max_ps(
                                min_distances_i[j][p], distances_reg);
                        const __m256i max_indices_new =
                                _mm256_castps_si256(_mm256_blendv_ps(
                                        _mm256_castsi256_ps(
                                                min_indices_i[j][p]),
                                        _mm256_castsi256_ps(indices_candidate),
                                        comparison));

                        distance_candidate = max_distances_new;
                        indices_candidate = max_indices_new;

                        min_distances_i[j][p] = min_distances_new;
                        min_indices_i[j][p] = min_indices_new;
                    }
                }

                current_indices =
                        _mm256_add_epi32(current_indices, indices_delta);
            }

            // fix the indices: translate lane-relative indices into
            // absolute positions within the distances array
            for (uint32_t j = 0; j < NBUCKETS_8; j++) {
                const __m256i offset =
                        _mm256_set1_epi32(n_per_beam * beam_index + j * 8);
                for (uint32_t p = 0; p < N; p++) {
                    min_indices_i[j][p] =
                            _mm256_add_epi32(min_indices_i[j][p], offset);
                }
            }

            // merge every bucket into the regular heap
            for (uint32_t p = 0; p < N; p++) {
                for (uint32_t j = 0; j < NBUCKETS_8; j++) {
                    int32_t min_indices_scalar[8];
                    float min_distances_scalar[8];

                    _mm256_storeu_si256(
                            (__m256i*)min_indices_scalar, min_indices_i[j][p]);
                    _mm256_storeu_ps(
                            min_distances_scalar, min_distances_i[j][p]);

                    // this exact way is needed to maintain the order as if the
                    // input elements were pushed to the heap sequentially
                    for (size_t j8 = 0; j8 < 8; j8++) {
                        const auto value = min_distances_scalar[j8];
                        const auto index = min_indices_scalar[j8];
                        // cmp2 breaks distance ties by index, matching
                        // sequential push order
                        if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
                            heap_replace_top<C>(
                                    k, bh_val, bh_ids, value, index);
                        }
                    }
                }
            }

            // process leftovers: the < NBUCKETS tail goes straight to the
            // heap, so this part is exact
            for (uint32_t ip = nb; ip < n_per_beam; ip++) {
                const int32_t index = ip + n_per_beam * beam_index;
                const float value = distances[index];

                if (C::cmp(bh_val[0], value)) {
                    heap_replace_top<C>(k, bh_val, bh_ids, value, index);
                }
            }
        }
    }
};

} // namespace faiss
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <limits>
#include <utility>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>

namespace faiss {

// This is the implementation of the idea and it is very slow,
// because a compiler is unable to vectorize it properly.

// Approximate top-k tracker: NBUCKETS independent buckets, each keeping
// its min-N candidates, merged into a regular heap at the end.
// Only the CMax<float, int> (max-heap, tracks minima) case is provided.
template <typename C, uint32_t NBUCKETS, uint32_t N>
struct HeapWithBuckets {
    // this case was not implemented yet.
};

// Scalar fallback specialization, mirroring the AVX2 variant's interface.
template <uint32_t NBUCKETS, uint32_t N>
struct HeapWithBuckets<CMax<float, int>, NBUCKETS, N> {
    // Add n elements into the heap (bh_val / bh_ids), approximately.
    // NOTE(review): the heap is assumed to be already initialized by the
    // caller (bh_val[0] is read before any push) — confirm against callers.
    static void addn(
            // number of elements
            const uint32_t n,
            // distances. It is assumed to have n elements.
            const float* const __restrict distances,
            // number of best elements to keep
            const uint32_t k,
            // output distances
            float* const __restrict bh_val,
            // output indices, each being within [0, n) range
            int32_t* const __restrict bh_ids) {
        // forward a call to bs_addn with 1 beam
        bs_addn(1, n, distances, k, bh_val, bh_ids);
    }

    // Beam-search flavor: processes beam_size contiguous slices of
    // n_per_beam distances each, all feeding the same output heap.
    static void bs_addn(
            // beam_size parameter of Beam Search algorithm
            const uint32_t beam_size,
            // number of elements per beam
            const uint32_t n_per_beam,
            // distances. It is assumed to have (n_per_beam * beam_size)
            // elements.
            const float* const __restrict distances,
            // number of best elements to keep
            const uint32_t k,
            // output distances
            float* const __restrict bh_val,
            // output indices, each being within [0, n_per_beam * beam_size)
            // range
            int32_t* const __restrict bh_ids) {
        // // Basically, the function runs beam_size iterations.
        // // Every iteration NBUCKETS * N elements are added to a regular heap.
        // // So, maximum number of added elements is beam_size * NBUCKETS * N.
        // // This number is expected to be less or equal than k.
        // FAISS_THROW_IF_NOT_FMT(
        //      beam_size * NBUCKETS * N >= k,
        //      "Cannot pick %d elements, only %d. "
        //      "Check the function and template arguments values.",
        //      k,
        //      beam_size * NBUCKETS * N);

        using C = CMax<float, int>;

        // main loop
        for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
            // per-bucket min-N state; for a fixed bucket j, the values at
            // p = 0..N-1 stay in ascending order (insertion below keeps it)
            float min_distances_i[N][NBUCKETS];
            int min_indices_i[N][NBUCKETS];

            for (uint32_t p = 0; p < N; p++) {
                for (uint32_t j = 0; j < NBUCKETS; j++) {
                    min_distances_i[p][j] = std::numeric_limits<float>::max();
                    min_indices_i[p][j] = 0;
                }
            }

            // number of elements that fill whole NBUCKETS-sized chunks;
            // the remainder is handled in the leftovers loop
            const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;

            // put the data into buckets
            for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
                for (uint32_t j = 0; j < NBUCKETS; j++) {
                    const int index = j + ip + n_per_beam * beam_index;
                    const float distance = distances[index];

                    // insertion cascade: the candidate starts as the new
                    // element; each level it displaces the stored value if
                    // smaller and carries the displaced one downward
                    int index_candidate = index;
                    float distance_candidate = distance;

                    for (uint32_t p = 0; p < N; p++) {
                        if (distance_candidate < min_distances_i[p][j]) {
                            std::swap(
                                    distance_candidate, min_distances_i[p][j]);
                            std::swap(index_candidate, min_indices_i[p][j]);
                        }
                    }
                }
            }

            // merge every bucket into the regular heap
            for (uint32_t p = 0; p < N; p++) {
                for (uint32_t j = 0; j < NBUCKETS; j++) {
                    // this exact way is needed to maintain the order as if the
                    // input elements were pushed to the heap sequentially

                    // cmp2 breaks distance ties by index, matching
                    // sequential push order
                    if (C::cmp2(bh_val[0],
                                min_distances_i[p][j],
                                bh_ids[0],
                                min_indices_i[p][j])) {
                        heap_replace_top<C>(
                                k,
                                bh_val,
                                bh_ids,
                                min_distances_i[p][j],
                                min_indices_i[p][j]);
                    }
                }
            }

            // process leftovers: the < NBUCKETS tail goes straight to the
            // heap, so this part is exact
            for (uint32_t ip = nb; ip < n_per_beam; ip++) {
                const int32_t index = ip + n_per_beam * beam_index;
                const float value = distances[index];

                if (C::cmp(bh_val[0], value)) {
                    heap_replace_top<C>(k, bh_val, bh_ids, value, index);
                }
            }
        }
    }
};

} // namespace faiss
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

/// Selects how approximate top-k computations trade accuracy for speed.
/// Any option other than EXACT_TOPK makes the computation faster but
/// approximate.
///
/// Naming scheme: Bx_Dy means x buckets, each tracking its y smallest
/// elements.
///
/// EXACT_TOPK is the default. APPROX_TOPK_BUCKETS_B16_D2 is a reasonable
/// first choice for experimentation.
///
/// Only a handful of (B, D) combinations are offered: the supply of SIMD
/// registers limits what is feasible, and benchmarking showed that some
/// combinations (for example, B32_D1 and B16_D1) were not precise enough
/// to be worth exposing.
///
/// TODO: Consider d-ary SIMD heap.

enum ApproxTopK_mode_t : int {
    EXACT_TOPK = 0,
    APPROX_TOPK_BUCKETS_B32_D2 = 1,
    APPROX_TOPK_BUCKETS_B8_D3 = 2,
    APPROX_TOPK_BUCKETS_B16_D2 = 3,
    APPROX_TOPK_BUCKETS_B8_D2 = 4,
};