faiss 0.1.3 → 0.1.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +36 -33
- data/vendor/faiss/faiss/AutoTune.h +6 -3
- data/vendor/faiss/faiss/Clustering.cpp +16 -12
- data/vendor/faiss/faiss/Index.cpp +3 -4
- data/vendor/faiss/faiss/Index.h +3 -3
- data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
- data/vendor/faiss/faiss/IndexBinary.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
- data/vendor/faiss/faiss/IndexFlat.h +0 -51
- data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
- data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
- data/vendor/faiss/faiss/IndexIVF.h +22 -15
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
- data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
- data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
- data/vendor/faiss/faiss/IndexRefine.h +73 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
- data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
- data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
- data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
- data/vendor/faiss/faiss/impl/io.cpp +33 -2
- data/vendor/faiss/faiss/impl/io.h +7 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
- data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
- data/vendor/faiss/faiss/index_factory.cpp +112 -7
- data/vendor/faiss/faiss/index_io.h +1 -48
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
- data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
- data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
- data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
- data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
- data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
- data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
- data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
- data/vendor/faiss/faiss/utils/Heap.h +61 -50
- data/vendor/faiss/faiss/utils/distances.cpp +164 -319
- data/vendor/faiss/faiss/utils/distances.h +28 -20
- data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
- data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
- data/vendor/faiss/faiss/utils/hamming.h +2 -7
- data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
- data/vendor/faiss/faiss/utils/partitioning.h +69 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
- data/vendor/faiss/faiss/utils/simdlib.h +31 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
- metadata +43 -141
- data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
- data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
- data/vendor/faiss/c_api/AutoTune_c.h +0 -66
- data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
- data/vendor/faiss/c_api/Clustering_c.h +0 -123
- data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
- data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
- data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
- data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
- data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
- data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
- data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
- data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
- data/vendor/faiss/c_api/IndexShards_c.h +0 -39
- data/vendor/faiss/c_api/Index_c.cpp +0 -105
- data/vendor/faiss/c_api/Index_c.h +0 -183
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
- data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
- data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
- data/vendor/faiss/c_api/clone_index_c.h +0 -32
- data/vendor/faiss/c_api/error_c.h +0 -42
- data/vendor/faiss/c_api/error_impl.cpp +0 -27
- data/vendor/faiss/c_api/error_impl.h +0 -16
- data/vendor/faiss/c_api/faiss_c.h +0 -58
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
- data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
- data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
- data/vendor/faiss/c_api/index_factory_c.h +0 -30
- data/vendor/faiss/c_api/index_io_c.cpp +0 -42
- data/vendor/faiss/c_api/index_io_c.h +0 -50
- data/vendor/faiss/c_api/macros_impl.h +0 -110
- data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
- data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
- data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
- data/vendor/faiss/misc/test_blas.cpp +0 -87
- data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
- data/vendor/faiss/tests/test_merge.cpp +0 -260
- data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
- data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
- data/vendor/faiss/tests/test_params_override.cpp +0 -236
- data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
- data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
- data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
- data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,141 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
|
9
|
+
#pragma once
|
10
|
+
|
11
|
+
#include <cstdint>
|
12
|
+
#include <cstdlib>
|
13
|
+
#include <cassert>
|
14
|
+
#include <cstring>
|
15
|
+
|
16
|
+
#include <algorithm>
|
17
|
+
|
18
|
+
#include <faiss/impl/platform_macros.h>
|
19
|
+
|
20
|
+
namespace faiss {
|
21
|
+
|
22
|
+
/// Tell whether a pointer lies on an A-byte boundary.
///
/// @tparam A  alignment in bytes, must be strictly positive (default 32)
/// @param x   pointer to test; it is only inspected, never dereferenced
/// @return    true when the address held by x is a multiple of A
template<int A=32>
inline bool is_aligned_pointer(const void* x)
{
    static_assert(A > 0, "alignment must be strictly positive");
    // uintptr_t is the integer type meant to hold a pointer value;
    // size_t is only guaranteed to hold object sizes.
    const std::uintptr_t address = reinterpret_cast<std::uintptr_t>(x);
    return address % static_cast<std::uintptr_t>(A) == 0;
}
|
28
|
+
|
29
|
+
// Class that manages suitably aligned arrays for SIMD.
// T should be a POD type: the elements are allocated, copied and cleared as
// raw bytes (posix_memalign / memcpy / memset), so no constructor or
// destructor of T is ever run. The default alignment is 32 for AVX.
// A is handed to posix_memalign unchanged as the requested alignment.
template<class T, int A=32>
struct AlignedTableTightAlloc {
    T * ptr;       // aligned buffer; nullptr while numel == 0
    size_t numel;  // number of elements allocated in ptr

    // Build an empty table (no allocation).
    AlignedTableTightAlloc(): ptr(nullptr), numel(0)
    { }

    // Build a table of n elements; their content is not initialized.
    explicit AlignedTableTightAlloc(size_t n): ptr(nullptr), numel(0)
    { resize(n); }

    // Size in bytes of one element.
    size_t itemsize() const {return sizeof(T); }

    // Reallocate the buffer to exactly n elements.
    // The first min(old size, n) elements are preserved, any additional
    // element is left uninitialized. Resizing to 0 releases the buffer.
    // Throws std::bad_alloc when the allocation fails or when the byte size
    // n * sizeof(T) is not representable; the table is unchanged in that case.
    void resize(size_t n) {
        if (numel == n) {
            return;
        }
        T * new_ptr = nullptr;
        if (n > 0) {
            // refuse sizes for which n * sizeof(T) would wrap around
            if (n > SIZE_MAX / sizeof(T)) {
                throw std::bad_alloc();
            }
            // go through a real void* rather than casting &new_ptr to void**
            void * raw = nullptr;
            int ret = posix_memalign(&raw, A, n * sizeof(T));
            if (ret != 0) {
                throw std::bad_alloc();
            }
            new_ptr = static_cast<T *>(raw);
            if (numel > 0) {
                memcpy(new_ptr, ptr, sizeof(T) * std::min(numel, n));
            }
        }
        numel = n;
        posix_memalign_free(ptr);
        ptr = new_ptr;
    }

    // Set every byte of the buffer to 0 (the size is unchanged).
    void clear() {
        // an empty table has ptr == nullptr, which must not reach memset
        if (numel > 0) {
            memset(ptr, 0, nbytes());
        }
    }
    size_t size() const {return numel; }
    size_t nbytes() const {return numel * sizeof(T); }

    T * get() {return ptr; }
    const T * get() const {return ptr; }
    T * data() {return ptr; }
    const T * data() const {return ptr; }
    // Unchecked element access; the const overload returns a copy.
    T & operator [] (size_t i) {return ptr[i]; }
    T operator [] (size_t i) const {return ptr[i]; }

    ~AlignedTableTightAlloc() {posix_memalign_free(ptr); }

    // Deep copy: resize to other's size, then copy its bytes.
    AlignedTableTightAlloc<T, A> & operator =
            (const AlignedTableTightAlloc<T, A> & other) {
        // self-assignment would memcpy a buffer onto itself
        if (this != &other) {
            resize(other.numel);
            if (numel > 0) {
                memcpy(ptr, other.ptr, sizeof(T) * numel);
            }
        }
        return *this;
    }

    // Deep copy constructor. The members must be initialized before
    // delegating to operator=, because resize() reads numel and frees ptr.
    AlignedTableTightAlloc(const AlignedTableTightAlloc<T, A> & other):
            ptr(nullptr), numel(0) {
        *this = other;
    }

};
|
90
|
+
|
91
|
+
// same as AlignedTableTightAlloc, but with geometric re-allocation
|
92
|
+
template<class T, int A=32>
|
93
|
+
struct AlignedTable {
|
94
|
+
AlignedTableTightAlloc<T, A> tab;
|
95
|
+
size_t numel = 0;
|
96
|
+
|
97
|
+
static size_t round_capacity(size_t n) {
|
98
|
+
if (n == 0) {
|
99
|
+
return 0;
|
100
|
+
}
|
101
|
+
if (n < 8 * A) {
|
102
|
+
return 8 * A;
|
103
|
+
}
|
104
|
+
size_t capacity = 8 * A;
|
105
|
+
while (capacity < n) {
|
106
|
+
capacity *= 2;
|
107
|
+
}
|
108
|
+
return capacity;
|
109
|
+
}
|
110
|
+
|
111
|
+
AlignedTable() {}
|
112
|
+
|
113
|
+
explicit AlignedTable(size_t n):
|
114
|
+
tab(round_capacity(n)),
|
115
|
+
numel(n)
|
116
|
+
{ }
|
117
|
+
|
118
|
+
size_t itemsize() const {return sizeof(T); }
|
119
|
+
|
120
|
+
void resize(size_t n) {
|
121
|
+
tab.resize(round_capacity(n));
|
122
|
+
numel = n;
|
123
|
+
}
|
124
|
+
|
125
|
+
void clear() { tab.clear(); }
|
126
|
+
size_t size() const {return numel; }
|
127
|
+
size_t nbytes() const {return numel * sizeof(T); }
|
128
|
+
|
129
|
+
T * get() {return tab.get(); }
|
130
|
+
const T * get() const {return tab.get(); }
|
131
|
+
T * data() {return tab.get(); }
|
132
|
+
const T * data() const {return tab.get(); }
|
133
|
+
T & operator [] (size_t i) {return tab.ptr[i]; }
|
134
|
+
T operator [] (size_t i) const {return tab.ptr[i]; }
|
135
|
+
|
136
|
+
// assign and copy constructor should work as expected
|
137
|
+
|
138
|
+
};
|
139
|
+
|
140
|
+
|
141
|
+
} // namespace faiss
|
@@ -46,8 +46,7 @@ void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
|
|
46
46
|
for (size_t j = 0; j < nj; j++) {
|
47
47
|
T ip = ip_line [j];
|
48
48
|
if (C::cmp(simi[0], ip)) {
|
49
|
-
|
50
|
-
heap_push<C> (k, simi, idxi, ip, j + j0);
|
49
|
+
heap_replace_top<C> (k, simi, idxi, ip, j + j0);
|
51
50
|
}
|
52
51
|
}
|
53
52
|
}
|
@@ -74,8 +73,7 @@ void HeapArray<C>::addn_with_ids (
|
|
74
73
|
for (size_t j = 0; j < nj; j++) {
|
75
74
|
T ip = ip_line [j];
|
76
75
|
if (C::cmp(simi[0], ip)) {
|
77
|
-
|
78
|
-
heap_push<C> (k, simi, idxi, ip, id_line [j]);
|
76
|
+
heap_replace_top<C> (k, simi, idxi, ip, id_line [j]);
|
79
77
|
}
|
80
78
|
}
|
81
79
|
}
|
@@ -5,16 +5,18 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
8
|
|
10
9
|
/*
|
11
|
-
* C++ support for heaps. The set of functions is tailored for
|
12
|
-
*
|
10
|
+
* C++ support for heaps. The set of functions is tailored for efficient
|
11
|
+
* similarity search.
|
13
12
|
*
|
14
|
-
* There is no specific object for a heap, and the functions that
|
15
|
-
*
|
16
|
-
*
|
13
|
+
* There is no specific object for a heap, and the functions that operate on a
|
14
|
+
* single heap are inlined, because heaps are often small. More complex
|
15
|
+
* functions are implemented in Heap.cpp
|
17
16
|
*
|
17
|
+
* All heap functions rely on a C template class that defines the type of the
|
18
|
+
* keys and values and their ordering (increasing with CMax and decreasing with
|
19
|
+
* CMin). The C types are defined in ordered_key_value.h
|
18
20
|
*/
|
19
21
|
|
20
22
|
|
@@ -31,51 +33,12 @@
|
|
31
33
|
|
32
34
|
#include <limits>
|
33
35
|
|
36
|
+
#include <faiss/utils/ordered_key_value.h>
|
34
37
|
|
35
38
|
namespace faiss {
|
36
39
|
|
37
|
-
/*******************************************************************
|
38
|
-
* C object: uniform handling of min and max heap
|
39
|
-
*******************************************************************/
|
40
|
-
|
41
|
-
/** The C object gives the type T of the values in the heap, the type
|
42
|
-
* of the keys, TI and the comparison that is done: > for the minheap
|
43
|
-
* and < for the maxheap. The neutral value will always be dropped in
|
44
|
-
* favor of any other value in the heap.
|
45
|
-
*/
|
46
|
-
|
47
|
-
template <typename T_, typename TI_>
|
48
|
-
struct CMax;
|
49
|
-
|
50
|
-
// traits of minheaps = heaps where the minimum value is stored on top
|
51
|
-
// useful to find the *max* values of an array
|
52
|
-
template <typename T_, typename TI_>
|
53
|
-
struct CMin {
|
54
|
-
typedef T_ T;
|
55
|
-
typedef TI_ TI;
|
56
|
-
typedef CMax<T_, TI_> Crev;
|
57
|
-
inline static bool cmp (T a, T b) {
|
58
|
-
return a < b;
|
59
|
-
}
|
60
|
-
inline static T neutral () {
|
61
|
-
return std::numeric_limits<T>::lowest();
|
62
|
-
}
|
63
|
-
};
|
64
40
|
|
65
41
|
|
66
|
-
template <typename T_, typename TI_>
|
67
|
-
struct CMax {
|
68
|
-
typedef T_ T;
|
69
|
-
typedef TI_ TI;
|
70
|
-
typedef CMin<T_, TI_> Crev;
|
71
|
-
inline static bool cmp (T a, T b) {
|
72
|
-
return a > b;
|
73
|
-
}
|
74
|
-
inline static T neutral () {
|
75
|
-
return std::numeric_limits<T>::max();
|
76
|
-
}
|
77
|
-
};
|
78
|
-
|
79
42
|
|
80
43
|
/*******************************************************************
|
81
44
|
* Basic heap ops: push and pop
|
@@ -142,6 +105,43 @@ void heap_push (size_t k,
|
|
142
105
|
|
143
106
|
|
144
107
|
|
108
|
+
/** Replace the top element of the heap defined by bh_val[0..k-1] and
 * bh_ids[0..k-1] with the pair (val, ids), then sift it down until the
 * ordering given by C::cmp holds again. This does in one pass what a pop
 * followed by a push would do.
 *
 * @param k       number of elements in the heap; must be >= 1, since slot 0
 *                is always written
 * @param bh_val  heap values, k entries, updated in place
 * @param bh_ids  heap ids, k entries, moved together with the values
 * @param val     value that replaces the current top
 * @param ids     id stored alongside val
 */
template <class C> inline
void heap_replace_top (size_t k,
                       typename C::T * bh_val, typename C::TI * bh_ids,
                       typename C::T val, typename C::TI ids)
{
    /* Nodes are numbered from 1 so that the children of node i are 2i and
     * 2i+1; node i lives in slot i - 1. The offset is applied on each access
     * instead of decrementing the array pointers, which would form a pointer
     * before the start of the arrays. */
    size_t i = 1;
    while (true) {
        const size_t i1 = i << 1;
        const size_t i2 = i1 + 1;
        if (i1 > k)
            break; /* node i has no child */
        /* when i2 == k + 1 there is only a left child: the right slot must
         * not be read, which the short-circuit guarantees */
        const size_t child =
            (i2 == k + 1 || C::cmp(bh_val[i1 - 1], bh_val[i2 - 1])) ? i1 : i2;
        if (C::cmp(val, bh_val[child - 1]))
            break; /* val can stay at node i */
        /* move the selected child up and continue from its former place */
        bh_val[i - 1] = bh_val[child - 1];
        bh_ids[i - 1] = bh_ids[child - 1];
        i = child;
    }
    bh_val[i - 1] = val;
    bh_ids[i - 1] = ids;
}
|
142
|
+
|
143
|
+
|
144
|
+
|
145
145
|
/* Partial instantiation for heaps with TI = int64_t */
|
146
146
|
|
147
147
|
template <typename T> inline
|
@@ -158,6 +158,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
|
158
158
|
}
|
159
159
|
|
160
160
|
|
161
|
+
template <typename T> inline
|
162
|
+
void minheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
163
|
+
{
|
164
|
+
heap_replace_top<CMin<T, int64_t> > (k, bh_val, bh_ids, val, ids);
|
165
|
+
}
|
166
|
+
|
167
|
+
|
161
168
|
template <typename T> inline
|
162
169
|
void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
|
163
170
|
{
|
@@ -172,6 +179,12 @@ void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
|
172
179
|
}
|
173
180
|
|
174
181
|
|
182
|
+
template <typename T> inline
|
183
|
+
void maxheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
184
|
+
{
|
185
|
+
heap_replace_top<CMax<T, int64_t> > (k, bh_val, bh_ids, val, ids);
|
186
|
+
}
|
187
|
+
|
175
188
|
|
176
189
|
/*******************************************************************
|
177
190
|
* Heap initialization
|
@@ -249,15 +262,13 @@ void heap_addn (size_t k,
|
|
249
262
|
if (ids)
|
250
263
|
for (i = 0; i < n; i++) {
|
251
264
|
if (C::cmp (bh_val[0], x[i])) {
|
252
|
-
|
253
|
-
heap_push<C> (k, bh_val, bh_ids, x[i], ids[i]);
|
265
|
+
heap_replace_top<C> (k, bh_val, bh_ids, x[i], ids[i]);
|
254
266
|
}
|
255
267
|
}
|
256
268
|
else
|
257
269
|
for (i = 0; i < n; i++) {
|
258
270
|
if (C::cmp (bh_val[0], x[i])) {
|
259
|
-
|
260
|
-
heap_push<C> (k, bh_val, bh_ids, x[i], i);
|
271
|
+
heap_replace_top<C> (k, bh_val, bh_ids, x[i], i);
|
261
272
|
}
|
262
273
|
}
|
263
274
|
}
|
@@ -19,6 +19,7 @@
|
|
19
19
|
|
20
20
|
#include <faiss/impl/AuxIndexStructures.h>
|
21
21
|
#include <faiss/impl/FaissAssert.h>
|
22
|
+
#include <faiss/impl/ResultHandler.h>
|
22
23
|
|
23
24
|
|
24
25
|
|
@@ -36,14 +37,6 @@ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
|
36
37
|
FINTEGER *lda, const float *b, FINTEGER *
|
37
38
|
ldb, float *beta, float *c, FINTEGER *ldc);
|
38
39
|
|
39
|
-
/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
|
40
|
-
|
41
|
-
int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
|
42
|
-
float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
|
43
|
-
|
44
|
-
int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
|
45
|
-
const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
|
46
|
-
float *beta, float *y, FINTEGER *incy);
|
47
40
|
|
48
41
|
}
|
49
42
|
|
@@ -58,34 +51,6 @@ namespace faiss {
|
|
58
51
|
|
59
52
|
|
60
53
|
|
61
|
-
/* Compute the inner product between a vector x and
|
62
|
-
a set of ny vectors y.
|
63
|
-
These functions are not intended to replace BLAS matrix-matrix, as they
|
64
|
-
would be significantly less efficient in this case. */
|
65
|
-
void fvec_inner_products_ny (float * ip,
|
66
|
-
const float * x,
|
67
|
-
const float * y,
|
68
|
-
size_t d, size_t ny)
|
69
|
-
{
|
70
|
-
// Not sure which one is fastest
|
71
|
-
#if 0
|
72
|
-
{
|
73
|
-
FINTEGER di = d;
|
74
|
-
FINTEGER nyi = ny;
|
75
|
-
float one = 1.0, zero = 0.0;
|
76
|
-
FINTEGER onei = 1;
|
77
|
-
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
78
|
-
}
|
79
|
-
#endif
|
80
|
-
for (size_t i = 0; i < ny; i++) {
|
81
|
-
ip[i] = fvec_inner_product (x, y, d);
|
82
|
-
y += d;
|
83
|
-
}
|
84
|
-
}
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
54
|
|
90
55
|
/* Compute the L2 norm of a set of nx vectors */
|
91
56
|
void fvec_norms_L2 (float * __restrict nr,
|
@@ -142,109 +107,112 @@ void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
|
142
107
|
* KNN functions
|
143
108
|
***************************************************************************/
|
144
109
|
|
110
|
+
namespace {
|
111
|
+
|
145
112
|
|
146
113
|
|
147
114
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
115
|
+
template<class ResultHandler>
|
116
|
+
void exhaustive_inner_product_seq (
|
117
|
+
const float * x,
|
118
|
+
const float * y,
|
119
|
+
size_t d, size_t nx, size_t ny,
|
120
|
+
ResultHandler &res)
|
152
121
|
{
|
153
|
-
size_t k = res->k;
|
154
122
|
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
155
123
|
|
156
124
|
check_period *= omp_get_max_threads();
|
157
125
|
|
126
|
+
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
127
|
+
|
158
128
|
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
159
129
|
size_t i1 = std::min(i0 + check_period, nx);
|
160
130
|
|
161
|
-
#pragma omp parallel
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
minheap_heapify (k, simi, idxi);
|
131
|
+
#pragma omp parallel
|
132
|
+
{
|
133
|
+
SingleResultHandler resi(res);
|
134
|
+
#pragma omp for
|
135
|
+
for (int64_t i = i0; i < i1; i++) {
|
136
|
+
const float * x_i = x + i * d;
|
137
|
+
const float * y_j = y;
|
170
138
|
|
171
|
-
|
172
|
-
float ip = fvec_inner_product (x_i, y_j, d);
|
139
|
+
resi.begin(i);
|
173
140
|
|
174
|
-
|
175
|
-
|
176
|
-
|
141
|
+
for (size_t j = 0; j < ny; j++) {
|
142
|
+
float ip = fvec_inner_product (x_i, y_j, d);
|
143
|
+
resi.add_result(ip, j);
|
144
|
+
y_j += d;
|
177
145
|
}
|
178
|
-
|
146
|
+
resi.end();
|
179
147
|
}
|
180
|
-
minheap_reorder (k, simi, idxi);
|
181
148
|
}
|
182
149
|
InterruptCallback::check ();
|
183
150
|
}
|
184
151
|
|
185
152
|
}
|
186
153
|
|
187
|
-
|
154
|
+
template<class ResultHandler>
|
155
|
+
void exhaustive_L2sqr_seq (
|
188
156
|
const float * x,
|
189
157
|
const float * y,
|
190
158
|
size_t d, size_t nx, size_t ny,
|
191
|
-
|
159
|
+
ResultHandler & res)
|
192
160
|
{
|
193
|
-
size_t k = res->k;
|
194
161
|
|
195
162
|
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
196
163
|
check_period *= omp_get_max_threads();
|
164
|
+
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
197
165
|
|
198
166
|
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
199
167
|
size_t i1 = std::min(i0 + check_period, nx);
|
200
168
|
|
201
|
-
#pragma omp parallel
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
if (disij < simi[0]) {
|
214
|
-
maxheap_pop (k, simi, idxi);
|
215
|
-
maxheap_push (k, simi, idxi, disij, j);
|
169
|
+
#pragma omp parallel
|
170
|
+
{
|
171
|
+
SingleResultHandler resi(res);
|
172
|
+
#pragma omp for
|
173
|
+
for (int64_t i = i0; i < i1; i++) {
|
174
|
+
const float * x_i = x + i * d;
|
175
|
+
const float * y_j = y;
|
176
|
+
resi.begin(i);
|
177
|
+
for (size_t j = 0; j < ny; j++) {
|
178
|
+
float disij = fvec_L2sqr (x_i, y_j, d);
|
179
|
+
resi.add_result(disij, j);
|
180
|
+
y_j += d;
|
216
181
|
}
|
217
|
-
|
182
|
+
resi.end();
|
218
183
|
}
|
219
|
-
maxheap_reorder (k, simi, idxi);
|
220
184
|
}
|
221
185
|
InterruptCallback::check ();
|
222
186
|
}
|
223
187
|
|
224
|
-
}
|
188
|
+
};
|
189
|
+
|
190
|
+
|
191
|
+
|
225
192
|
|
226
193
|
|
227
194
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
228
|
-
|
195
|
+
template<class ResultHandler>
|
196
|
+
void exhaustive_inner_product_blas (
|
229
197
|
const float * x,
|
230
198
|
const float * y,
|
231
199
|
size_t d, size_t nx, size_t ny,
|
232
|
-
|
200
|
+
ResultHandler & res)
|
233
201
|
{
|
234
|
-
res->heapify ();
|
235
|
-
|
236
202
|
// BLAS does not like empty matrices
|
237
203
|
if (nx == 0 || ny == 0) return;
|
238
204
|
|
239
205
|
/* block sizes */
|
240
|
-
const size_t bs_x =
|
241
|
-
|
206
|
+
const size_t bs_x = distance_compute_blas_query_bs;
|
207
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
242
208
|
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
243
209
|
|
244
210
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
245
211
|
size_t i1 = i0 + bs_x;
|
246
212
|
if(i1 > nx) i1 = nx;
|
247
213
|
|
214
|
+
res.begin_multiple(i0, i1);
|
215
|
+
|
248
216
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
249
217
|
size_t j1 = j0 + bs_y;
|
250
218
|
if (j1 > ny) j1 = ny;
|
@@ -258,46 +226,54 @@ static void knn_inner_product_blas (
|
|
258
226
|
ip_block.get(), &nyi);
|
259
227
|
}
|
260
228
|
|
261
|
-
|
262
|
-
|
229
|
+
res.add_results(j0, j1, ip_block.get());
|
230
|
+
|
263
231
|
}
|
232
|
+
res.end_multiple();
|
264
233
|
InterruptCallback::check ();
|
234
|
+
|
265
235
|
}
|
266
|
-
res->reorder ();
|
267
236
|
}
|
268
237
|
|
238
|
+
|
239
|
+
|
240
|
+
|
269
241
|
// distance correction is an operator that can be applied to transform
|
270
242
|
// the distances
|
271
|
-
template<class
|
272
|
-
|
243
|
+
template<class ResultHandler>
|
244
|
+
void exhaustive_L2sqr_blas (
|
245
|
+
const float * x,
|
273
246
|
const float * y,
|
274
247
|
size_t d, size_t nx, size_t ny,
|
275
|
-
|
276
|
-
const
|
248
|
+
ResultHandler & res,
|
249
|
+
const float *y_norms = nullptr)
|
277
250
|
{
|
278
|
-
res->heapify ();
|
279
|
-
|
280
251
|
// BLAS does not like empty matrices
|
281
252
|
if (nx == 0 || ny == 0) return;
|
282
253
|
|
283
|
-
size_t k = res->k;
|
284
|
-
|
285
254
|
/* block sizes */
|
286
|
-
const size_t bs_x =
|
255
|
+
const size_t bs_x = distance_compute_blas_query_bs;
|
256
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
287
257
|
// const size_t bs_x = 16, bs_y = 16;
|
288
|
-
float
|
289
|
-
float
|
290
|
-
float
|
291
|
-
ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
|
258
|
+
std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
|
259
|
+
std::unique_ptr<float []> x_norms(new float[nx]);
|
260
|
+
std::unique_ptr<float []> del2;
|
292
261
|
|
293
|
-
fvec_norms_L2sqr (x_norms, x, d, nx);
|
294
|
-
fvec_norms_L2sqr (y_norms, y, d, ny);
|
262
|
+
fvec_norms_L2sqr (x_norms.get(), x, d, nx);
|
295
263
|
|
264
|
+
if (!y_norms) {
|
265
|
+
float *y_norms2 = new float[ny];
|
266
|
+
del2.reset(y_norms2);
|
267
|
+
fvec_norms_L2sqr (y_norms2, y, d, ny);
|
268
|
+
y_norms = y_norms2;
|
269
|
+
}
|
296
270
|
|
297
271
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
298
272
|
size_t i1 = i0 + bs_x;
|
299
273
|
if(i1 > nx) i1 = nx;
|
300
274
|
|
275
|
+
res.begin_multiple(i0, i1);
|
276
|
+
|
301
277
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
302
278
|
size_t j1 = j0 + bs_y;
|
303
279
|
if (j1 > ny) j1 = ny;
|
@@ -308,42 +284,34 @@ static void knn_L2sqr_blas (const float * x,
|
|
308
284
|
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
309
285
|
y + j0 * d, &di,
|
310
286
|
x + i0 * d, &di, &zero,
|
311
|
-
ip_block, &nyi);
|
287
|
+
ip_block.get(), &nyi);
|
312
288
|
}
|
313
289
|
|
314
|
-
/* collect minima */
|
315
|
-
#pragma omp parallel for
|
316
290
|
for (int64_t i = i0; i < i1; i++) {
|
317
|
-
float *
|
318
|
-
int64_t * __restrict idxi = res->get_ids (i);
|
319
|
-
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
291
|
+
float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
320
292
|
|
321
293
|
for (size_t j = j0; j < j1; j++) {
|
322
|
-
float ip = *ip_line
|
294
|
+
float ip = *ip_line;
|
323
295
|
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
324
296
|
|
325
297
|
// negative values can occur for identical vectors
|
326
298
|
// due to roundoff errors
|
327
299
|
if (dis < 0) dis = 0;
|
328
300
|
|
329
|
-
|
330
|
-
|
331
|
-
if (dis < simi[0]) {
|
332
|
-
maxheap_pop (k, simi, idxi);
|
333
|
-
maxheap_push (k, simi, idxi, dis, j);
|
334
|
-
}
|
301
|
+
*ip_line = dis;
|
302
|
+
ip_line++;
|
335
303
|
}
|
336
304
|
}
|
305
|
+
res.add_results(j0, j1, ip_block.get());
|
337
306
|
}
|
307
|
+
res.end_multiple();
|
338
308
|
InterruptCallback::check ();
|
339
309
|
}
|
340
|
-
res->reorder ();
|
341
|
-
|
342
310
|
}
|
343
311
|
|
344
312
|
|
345
313
|
|
346
|
-
|
314
|
+
} // anonymous namespace
|
347
315
|
|
348
316
|
|
349
317
|
|
@@ -354,58 +322,103 @@ static void knn_L2sqr_blas (const float * x,
|
|
354
322
|
*******************************************************/
|
355
323
|
|
356
324
|
int distance_compute_blas_threshold = 20;
|
325
|
+
int distance_compute_blas_query_bs = 4096;
|
326
|
+
int distance_compute_blas_database_bs = 1024;
|
327
|
+
int distance_compute_min_k_reservoir = 100;
|
357
328
|
|
358
329
|
void knn_inner_product (const float * x,
|
359
330
|
const float * y,
|
360
331
|
size_t d, size_t nx, size_t ny,
|
361
|
-
float_minheap_array_t *
|
332
|
+
float_minheap_array_t * ha)
|
362
333
|
{
|
363
|
-
if (
|
364
|
-
|
334
|
+
if (ha->k < distance_compute_min_k_reservoir) {
|
335
|
+
HeapResultHandler<CMin<float, int64_t>> res(
|
336
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
337
|
+
if (nx < distance_compute_blas_threshold) {
|
338
|
+
exhaustive_inner_product_seq (x, y, d, nx, ny, res);
|
339
|
+
} else {
|
340
|
+
exhaustive_inner_product_blas (x, y, d, nx, ny, res);
|
341
|
+
}
|
365
342
|
} else {
|
366
|
-
|
343
|
+
ReservoirResultHandler<CMin<float, int64_t>> res(
|
344
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
345
|
+
if (nx < distance_compute_blas_threshold) {
|
346
|
+
exhaustive_inner_product_seq (x, y, d, nx, ny, res);
|
347
|
+
} else {
|
348
|
+
exhaustive_inner_product_blas (x, y, d, nx, ny, res);
|
349
|
+
}
|
367
350
|
}
|
368
351
|
}
|
369
352
|
|
370
353
|
|
371
354
|
|
372
|
-
|
373
|
-
|
374
|
-
|
355
|
+
|
356
|
+
void knn_L2sqr (
|
357
|
+
const float * x,
|
358
|
+
const float * y,
|
359
|
+
size_t d, size_t nx, size_t ny,
|
360
|
+
float_maxheap_array_t * ha,
|
361
|
+
const float *y_norm2
|
362
|
+
) {
|
363
|
+
|
364
|
+
if (ha->k < distance_compute_min_k_reservoir) {
|
365
|
+
HeapResultHandler<CMax<float, int64_t>> res(
|
366
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
367
|
+
|
368
|
+
if (nx < distance_compute_blas_threshold) {
|
369
|
+
exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
|
370
|
+
} else {
|
371
|
+
exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
|
372
|
+
}
|
373
|
+
} else {
|
374
|
+
ReservoirResultHandler<CMax<float, int64_t>> res(
|
375
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
376
|
+
if (nx < distance_compute_blas_threshold) {
|
377
|
+
exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
|
378
|
+
} else {
|
379
|
+
exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
|
380
|
+
}
|
375
381
|
}
|
376
|
-
}
|
382
|
+
}
|
377
383
|
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
384
|
+
|
385
|
+
/***************************************************************************
|
386
|
+
* Range search
|
387
|
+
***************************************************************************/
|
388
|
+
|
389
|
+
|
390
|
+
|
391
|
+
|
392
|
+
void range_search_L2sqr (
|
393
|
+
const float * x,
|
394
|
+
const float * y,
|
395
|
+
size_t d, size_t nx, size_t ny,
|
396
|
+
float radius,
|
397
|
+
RangeSearchResult *res)
|
382
398
|
{
|
399
|
+
RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
|
383
400
|
if (nx < distance_compute_blas_threshold) {
|
384
|
-
|
401
|
+
exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
|
385
402
|
} else {
|
386
|
-
|
387
|
-
knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
|
403
|
+
exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
|
388
404
|
}
|
389
405
|
}
|
390
406
|
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
void knn_L2sqr_base_shift (
|
399
|
-
const float * x,
|
400
|
-
const float * y,
|
401
|
-
size_t d, size_t nx, size_t ny,
|
402
|
-
float_maxheap_array_t * res,
|
403
|
-
const float *base_shift)
|
407
|
+
void range_search_inner_product (
|
408
|
+
const float * x,
|
409
|
+
const float * y,
|
410
|
+
size_t d, size_t nx, size_t ny,
|
411
|
+
float radius,
|
412
|
+
RangeSearchResult *res)
|
404
413
|
{
|
405
|
-
BaseShiftDistanceCorrection corr = {base_shift};
|
406
|
-
knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
|
407
|
-
}
|
408
414
|
|
415
|
+
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
|
416
|
+
if (nx < distance_compute_blas_threshold) {
|
417
|
+
exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
|
418
|
+
} else {
|
419
|
+
exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
|
420
|
+
}
|
421
|
+
}
|
409
422
|
|
410
423
|
|
411
424
|
/***************************************************************************
|
@@ -509,8 +522,7 @@ void knn_inner_products_by_idx (const float * x,
|
|
509
522
|
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
|
510
523
|
|
511
524
|
if (ip > simi[0]) {
|
512
|
-
|
513
|
-
minheap_push (k, simi, idxi, ip, idsi[j]);
|
525
|
+
minheap_replace_top (k, simi, idxi, ip, idsi[j]);
|
514
526
|
}
|
515
527
|
}
|
516
528
|
minheap_reorder (k, simi, idxi);
|
@@ -537,8 +549,7 @@ void knn_L2sqr_by_idx (const float * x,
|
|
537
549
|
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
|
538
550
|
|
539
551
|
if (disij < simi[0]) {
|
540
|
-
|
541
|
-
maxheap_push (k, simi, idxi, disij, idsi[j]);
|
552
|
+
maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
|
542
553
|
}
|
543
554
|
}
|
544
555
|
maxheap_reorder (res->k, simi, idxi);
|
@@ -550,172 +561,6 @@ void knn_L2sqr_by_idx (const float * x,
|
|
550
561
|
|
551
562
|
|
552
563
|
|
553
|
-
/***************************************************************************
|
554
|
-
* Range search
|
555
|
-
***************************************************************************/
|
556
|
-
|
557
|
-
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
558
|
-
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
559
|
-
*/
|
560
|
-
template <bool compute_l2>
|
561
|
-
static void range_search_blas (
|
562
|
-
const float * x,
|
563
|
-
const float * y,
|
564
|
-
size_t d, size_t nx, size_t ny,
|
565
|
-
float radius,
|
566
|
-
RangeSearchResult *result)
|
567
|
-
{
|
568
|
-
|
569
|
-
// BLAS does not like empty matrices
|
570
|
-
if (nx == 0 || ny == 0) return;
|
571
|
-
|
572
|
-
/* block sizes */
|
573
|
-
const size_t bs_x = 4096, bs_y = 1024;
|
574
|
-
// const size_t bs_x = 16, bs_y = 16;
|
575
|
-
float *ip_block = new float[bs_x * bs_y];
|
576
|
-
ScopeDeleter<float> del0(ip_block);
|
577
|
-
|
578
|
-
float *x_norms = nullptr, *y_norms = nullptr;
|
579
|
-
ScopeDeleter<float> del1, del2;
|
580
|
-
if (compute_l2) {
|
581
|
-
x_norms = new float[nx];
|
582
|
-
del1.set (x_norms);
|
583
|
-
fvec_norms_L2sqr (x_norms, x, d, nx);
|
584
|
-
|
585
|
-
y_norms = new float[ny];
|
586
|
-
del2.set (y_norms);
|
587
|
-
fvec_norms_L2sqr (y_norms, y, d, ny);
|
588
|
-
}
|
589
|
-
|
590
|
-
std::vector <RangeSearchPartialResult *> partial_results;
|
591
|
-
|
592
|
-
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
593
|
-
size_t j1 = j0 + bs_y;
|
594
|
-
if (j1 > ny) j1 = ny;
|
595
|
-
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
|
596
|
-
partial_results.push_back (pres);
|
597
|
-
|
598
|
-
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
599
|
-
size_t i1 = i0 + bs_x;
|
600
|
-
if(i1 > nx) i1 = nx;
|
601
|
-
|
602
|
-
/* compute the actual dot products */
|
603
|
-
{
|
604
|
-
float one = 1, zero = 0;
|
605
|
-
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
606
|
-
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
607
|
-
y + j0 * d, &di,
|
608
|
-
x + i0 * d, &di, &zero,
|
609
|
-
ip_block, &nyi);
|
610
|
-
}
|
611
|
-
|
612
|
-
|
613
|
-
for (size_t i = i0; i < i1; i++) {
|
614
|
-
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
615
|
-
|
616
|
-
RangeQueryResult & qres = pres->new_result (i);
|
617
|
-
|
618
|
-
for (size_t j = j0; j < j1; j++) {
|
619
|
-
float ip = *ip_line++;
|
620
|
-
if (compute_l2) {
|
621
|
-
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
622
|
-
if (dis < radius) {
|
623
|
-
qres.add (dis, j);
|
624
|
-
}
|
625
|
-
} else {
|
626
|
-
if (ip > radius) {
|
627
|
-
qres.add (ip, j);
|
628
|
-
}
|
629
|
-
}
|
630
|
-
}
|
631
|
-
}
|
632
|
-
}
|
633
|
-
InterruptCallback::check ();
|
634
|
-
}
|
635
|
-
|
636
|
-
RangeSearchPartialResult::merge (partial_results);
|
637
|
-
}
|
638
|
-
|
639
|
-
|
640
|
-
template <bool compute_l2>
|
641
|
-
static void range_search_sse (const float * x,
|
642
|
-
const float * y,
|
643
|
-
size_t d, size_t nx, size_t ny,
|
644
|
-
float radius,
|
645
|
-
RangeSearchResult *res)
|
646
|
-
{
|
647
|
-
|
648
|
-
#pragma omp parallel
|
649
|
-
{
|
650
|
-
RangeSearchPartialResult pres (res);
|
651
|
-
|
652
|
-
#pragma omp for
|
653
|
-
for (int64_t i = 0; i < nx; i++) {
|
654
|
-
const float * x_ = x + i * d;
|
655
|
-
const float * y_ = y;
|
656
|
-
size_t j;
|
657
|
-
|
658
|
-
RangeQueryResult & qres = pres.new_result (i);
|
659
|
-
|
660
|
-
for (j = 0; j < ny; j++) {
|
661
|
-
if (compute_l2) {
|
662
|
-
float disij = fvec_L2sqr (x_, y_, d);
|
663
|
-
if (disij < radius) {
|
664
|
-
qres.add (disij, j);
|
665
|
-
}
|
666
|
-
} else {
|
667
|
-
float ip = fvec_inner_product (x_, y_, d);
|
668
|
-
if (ip > radius) {
|
669
|
-
qres.add (ip, j);
|
670
|
-
}
|
671
|
-
}
|
672
|
-
y_ += d;
|
673
|
-
}
|
674
|
-
|
675
|
-
}
|
676
|
-
pres.finalize ();
|
677
|
-
}
|
678
|
-
|
679
|
-
// check just at the end because the use case is typically just
|
680
|
-
// when the nb of queries is low.
|
681
|
-
InterruptCallback::check();
|
682
|
-
}
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
void range_search_L2sqr (
|
689
|
-
const float * x,
|
690
|
-
const float * y,
|
691
|
-
size_t d, size_t nx, size_t ny,
|
692
|
-
float radius,
|
693
|
-
RangeSearchResult *res)
|
694
|
-
{
|
695
|
-
|
696
|
-
if (nx < distance_compute_blas_threshold) {
|
697
|
-
range_search_sse<true> (x, y, d, nx, ny, radius, res);
|
698
|
-
} else {
|
699
|
-
range_search_blas<true> (x, y, d, nx, ny, radius, res);
|
700
|
-
}
|
701
|
-
}
|
702
|
-
|
703
|
-
void range_search_inner_product (
|
704
|
-
const float * x,
|
705
|
-
const float * y,
|
706
|
-
size_t d, size_t nx, size_t ny,
|
707
|
-
float radius,
|
708
|
-
RangeSearchResult *res)
|
709
|
-
{
|
710
|
-
|
711
|
-
if (nx < distance_compute_blas_threshold) {
|
712
|
-
range_search_sse<false> (x, y, d, nx, ny, radius, res);
|
713
|
-
} else {
|
714
|
-
range_search_blas<false> (x, y, d, nx, ny, radius, res);
|
715
|
-
}
|
716
|
-
}
|
717
|
-
|
718
|
-
|
719
564
|
void pairwise_L2sqr (int64_t d,
|
720
565
|
int64_t nq, const float *xq,
|
721
566
|
int64_t nb, const float *xb,
|