faiss 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +36 -33
- data/vendor/faiss/faiss/AutoTune.h +6 -3
- data/vendor/faiss/faiss/Clustering.cpp +16 -12
- data/vendor/faiss/faiss/Index.cpp +3 -4
- data/vendor/faiss/faiss/Index.h +3 -3
- data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
- data/vendor/faiss/faiss/IndexBinary.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
- data/vendor/faiss/faiss/IndexFlat.h +0 -51
- data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
- data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
- data/vendor/faiss/faiss/IndexIVF.h +22 -15
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
- data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
- data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
- data/vendor/faiss/faiss/IndexRefine.h +73 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
- data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
- data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
- data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
- data/vendor/faiss/faiss/impl/io.cpp +33 -2
- data/vendor/faiss/faiss/impl/io.h +7 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
- data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
- data/vendor/faiss/faiss/index_factory.cpp +112 -7
- data/vendor/faiss/faiss/index_io.h +1 -48
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
- data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
- data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
- data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
- data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
- data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
- data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
- data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
- data/vendor/faiss/faiss/utils/Heap.h +61 -50
- data/vendor/faiss/faiss/utils/distances.cpp +164 -319
- data/vendor/faiss/faiss/utils/distances.h +28 -20
- data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
- data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
- data/vendor/faiss/faiss/utils/hamming.h +2 -7
- data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
- data/vendor/faiss/faiss/utils/partitioning.h +69 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
- data/vendor/faiss/faiss/utils/simdlib.h +31 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
- metadata +43 -141
- data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
- data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
- data/vendor/faiss/c_api/AutoTune_c.h +0 -66
- data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
- data/vendor/faiss/c_api/Clustering_c.h +0 -123
- data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
- data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
- data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
- data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
- data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
- data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
- data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
- data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
- data/vendor/faiss/c_api/IndexShards_c.h +0 -39
- data/vendor/faiss/c_api/Index_c.cpp +0 -105
- data/vendor/faiss/c_api/Index_c.h +0 -183
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
- data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
- data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
- data/vendor/faiss/c_api/clone_index_c.h +0 -32
- data/vendor/faiss/c_api/error_c.h +0 -42
- data/vendor/faiss/c_api/error_impl.cpp +0 -27
- data/vendor/faiss/c_api/error_impl.h +0 -16
- data/vendor/faiss/c_api/faiss_c.h +0 -58
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
- data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
- data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
- data/vendor/faiss/c_api/index_factory_c.h +0 -30
- data/vendor/faiss/c_api/index_io_c.cpp +0 -42
- data/vendor/faiss/c_api/index_io_c.h +0 -50
- data/vendor/faiss/c_api/macros_impl.h +0 -110
- data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
- data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
- data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
- data/vendor/faiss/misc/test_blas.cpp +0 -87
- data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
- data/vendor/faiss/tests/test_merge.cpp +0 -260
- data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
- data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
- data/vendor/faiss/tests/test_params_override.cpp +0 -236
- data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
- data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
- data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
- data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,141 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
|
9
|
+
#pragma once
|
10
|
+
|
11
|
+
#include <cstdint>
|
12
|
+
#include <cstdlib>
|
13
|
+
#include <cassert>
|
14
|
+
#include <cstring>
|
15
|
+
|
16
|
+
#include <algorithm>
|
17
|
+
|
18
|
+
#include <faiss/impl/platform_macros.h>
|
19
|
+
|
20
|
+
namespace faiss {
|
21
|
+
|
22
|
+
template<int A=32>
|
23
|
+
inline bool is_aligned_pointer(const void* x)
|
24
|
+
{
|
25
|
+
size_t xi = (size_t)x;
|
26
|
+
return xi % A == 0;
|
27
|
+
}
|
28
|
+
|
29
|
+
// class that manages suitably aligned arrays for SIMD
|
30
|
+
// T should be a POD type. The default alignment is 32 for AVX
|
31
|
+
template<class T, int A=32>
|
32
|
+
struct AlignedTableTightAlloc {
|
33
|
+
T * ptr;
|
34
|
+
size_t numel;
|
35
|
+
|
36
|
+
AlignedTableTightAlloc(): ptr(nullptr), numel(0)
|
37
|
+
{ }
|
38
|
+
|
39
|
+
explicit AlignedTableTightAlloc(size_t n): ptr(nullptr), numel(0)
|
40
|
+
{ resize(n); }
|
41
|
+
|
42
|
+
size_t itemsize() const {return sizeof(T); }
|
43
|
+
|
44
|
+
void resize(size_t n) {
|
45
|
+
if (numel == n) {
|
46
|
+
return;
|
47
|
+
}
|
48
|
+
T * new_ptr;
|
49
|
+
if (n > 0) {
|
50
|
+
int ret = posix_memalign((void**)&new_ptr, A, n * sizeof(T));
|
51
|
+
if (ret != 0) {
|
52
|
+
throw std::bad_alloc();
|
53
|
+
}
|
54
|
+
if (numel > 0) {
|
55
|
+
memcpy(new_ptr, ptr, sizeof(T) * std::min(numel, n));
|
56
|
+
}
|
57
|
+
} else {
|
58
|
+
new_ptr = nullptr;
|
59
|
+
}
|
60
|
+
numel = n;
|
61
|
+
posix_memalign_free(ptr);
|
62
|
+
ptr = new_ptr;
|
63
|
+
}
|
64
|
+
|
65
|
+
void clear() {memset(ptr, 0, nbytes()); }
|
66
|
+
size_t size() const {return numel; }
|
67
|
+
size_t nbytes() const {return numel * sizeof(T); }
|
68
|
+
|
69
|
+
T * get() {return ptr; }
|
70
|
+
const T * get() const {return ptr; }
|
71
|
+
T * data() {return ptr; }
|
72
|
+
const T * data() const {return ptr; }
|
73
|
+
T & operator [] (size_t i) {return ptr[i]; }
|
74
|
+
T operator [] (size_t i) const {return ptr[i]; }
|
75
|
+
|
76
|
+
~AlignedTableTightAlloc() {posix_memalign_free(ptr); }
|
77
|
+
|
78
|
+
AlignedTableTightAlloc<T, A> & operator =
|
79
|
+
(const AlignedTableTightAlloc<T, A> & other) {
|
80
|
+
resize(other.numel);
|
81
|
+
memcpy(ptr, other.ptr, sizeof(T) * numel);
|
82
|
+
return *this;
|
83
|
+
}
|
84
|
+
|
85
|
+
AlignedTableTightAlloc(const AlignedTableTightAlloc<T, A> & other) {
|
86
|
+
*this = other;
|
87
|
+
}
|
88
|
+
|
89
|
+
};
|
90
|
+
|
91
|
+
// same as AlignedTableTightAlloc, but with geometric re-allocation
|
92
|
+
template<class T, int A=32>
|
93
|
+
struct AlignedTable {
|
94
|
+
AlignedTableTightAlloc<T, A> tab;
|
95
|
+
size_t numel = 0;
|
96
|
+
|
97
|
+
static size_t round_capacity(size_t n) {
|
98
|
+
if (n == 0) {
|
99
|
+
return 0;
|
100
|
+
}
|
101
|
+
if (n < 8 * A) {
|
102
|
+
return 8 * A;
|
103
|
+
}
|
104
|
+
size_t capacity = 8 * A;
|
105
|
+
while (capacity < n) {
|
106
|
+
capacity *= 2;
|
107
|
+
}
|
108
|
+
return capacity;
|
109
|
+
}
|
110
|
+
|
111
|
+
AlignedTable() {}
|
112
|
+
|
113
|
+
explicit AlignedTable(size_t n):
|
114
|
+
tab(round_capacity(n)),
|
115
|
+
numel(n)
|
116
|
+
{ }
|
117
|
+
|
118
|
+
size_t itemsize() const {return sizeof(T); }
|
119
|
+
|
120
|
+
void resize(size_t n) {
|
121
|
+
tab.resize(round_capacity(n));
|
122
|
+
numel = n;
|
123
|
+
}
|
124
|
+
|
125
|
+
void clear() { tab.clear(); }
|
126
|
+
size_t size() const {return numel; }
|
127
|
+
size_t nbytes() const {return numel * sizeof(T); }
|
128
|
+
|
129
|
+
T * get() {return tab.get(); }
|
130
|
+
const T * get() const {return tab.get(); }
|
131
|
+
T * data() {return tab.get(); }
|
132
|
+
const T * data() const {return tab.get(); }
|
133
|
+
T & operator [] (size_t i) {return tab.ptr[i]; }
|
134
|
+
T operator [] (size_t i) const {return tab.ptr[i]; }
|
135
|
+
|
136
|
+
// assign and copy constructor should work as expected
|
137
|
+
|
138
|
+
};
|
139
|
+
|
140
|
+
|
141
|
+
} // namespace faiss
|
@@ -46,8 +46,7 @@ void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
|
|
46
46
|
for (size_t j = 0; j < nj; j++) {
|
47
47
|
T ip = ip_line [j];
|
48
48
|
if (C::cmp(simi[0], ip)) {
|
49
|
-
|
50
|
-
heap_push<C> (k, simi, idxi, ip, j + j0);
|
49
|
+
heap_replace_top<C> (k, simi, idxi, ip, j + j0);
|
51
50
|
}
|
52
51
|
}
|
53
52
|
}
|
@@ -74,8 +73,7 @@ void HeapArray<C>::addn_with_ids (
|
|
74
73
|
for (size_t j = 0; j < nj; j++) {
|
75
74
|
T ip = ip_line [j];
|
76
75
|
if (C::cmp(simi[0], ip)) {
|
77
|
-
|
78
|
-
heap_push<C> (k, simi, idxi, ip, id_line [j]);
|
76
|
+
heap_replace_top<C> (k, simi, idxi, ip, id_line [j]);
|
79
77
|
}
|
80
78
|
}
|
81
79
|
}
|
@@ -5,16 +5,18 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
8
|
|
10
9
|
/*
|
11
|
-
* C++ support for heaps. The set of functions is tailored for
|
12
|
-
*
|
10
|
+
* C++ support for heaps. The set of functions is tailored for efficient
|
11
|
+
* similarity search.
|
13
12
|
*
|
14
|
-
* There is no specific object for a heap, and the functions that
|
15
|
-
*
|
16
|
-
*
|
13
|
+
* There is no specific object for a heap, and the functions that operate on a
|
14
|
+
* single heap are inlined, because heaps are often small. More complex
|
15
|
+
* functions are implemented in Heap.cpp
|
17
16
|
*
|
17
|
+
* All heap functions rely on a C template class that defines the type of the
|
18
|
+
* keys and values and their ordering (increasing with CMax and decreasing with
|
19
|
+
* CMin). The C types are defined in ordered_key_value.h
|
18
20
|
*/
|
19
21
|
|
20
22
|
|
@@ -31,51 +33,12 @@
|
|
31
33
|
|
32
34
|
#include <limits>
|
33
35
|
|
36
|
+
#include <faiss/utils/ordered_key_value.h>
|
34
37
|
|
35
38
|
namespace faiss {
|
36
39
|
|
37
|
-
/*******************************************************************
|
38
|
-
* C object: uniform handling of min and max heap
|
39
|
-
*******************************************************************/
|
40
|
-
|
41
|
-
/** The C object gives the type T of the values in the heap, the type
|
42
|
-
* of the keys, TI and the comparison that is done: > for the minheap
|
43
|
-
* and < for the maxheap. The neutral value will always be dropped in
|
44
|
-
* favor of any other value in the heap.
|
45
|
-
*/
|
46
|
-
|
47
|
-
template <typename T_, typename TI_>
|
48
|
-
struct CMax;
|
49
|
-
|
50
|
-
// traits of minheaps = heaps where the minimum value is stored on top
|
51
|
-
// useful to find the *max* values of an array
|
52
|
-
template <typename T_, typename TI_>
|
53
|
-
struct CMin {
|
54
|
-
typedef T_ T;
|
55
|
-
typedef TI_ TI;
|
56
|
-
typedef CMax<T_, TI_> Crev;
|
57
|
-
inline static bool cmp (T a, T b) {
|
58
|
-
return a < b;
|
59
|
-
}
|
60
|
-
inline static T neutral () {
|
61
|
-
return std::numeric_limits<T>::lowest();
|
62
|
-
}
|
63
|
-
};
|
64
40
|
|
65
41
|
|
66
|
-
template <typename T_, typename TI_>
|
67
|
-
struct CMax {
|
68
|
-
typedef T_ T;
|
69
|
-
typedef TI_ TI;
|
70
|
-
typedef CMin<T_, TI_> Crev;
|
71
|
-
inline static bool cmp (T a, T b) {
|
72
|
-
return a > b;
|
73
|
-
}
|
74
|
-
inline static T neutral () {
|
75
|
-
return std::numeric_limits<T>::max();
|
76
|
-
}
|
77
|
-
};
|
78
|
-
|
79
42
|
|
80
43
|
/*******************************************************************
|
81
44
|
* Basic heap ops: push and pop
|
@@ -142,6 +105,43 @@ void heap_push (size_t k,
|
|
142
105
|
|
143
106
|
|
144
107
|
|
108
|
+
/** Replace the top element of the heap defined by bh_val[0..k-1] and
|
109
|
+
* bh_ids[0..k-1].
|
110
|
+
*/
|
111
|
+
template <class C> inline
|
112
|
+
void heap_replace_top (size_t k,
|
113
|
+
typename C::T * bh_val, typename C::TI * bh_ids,
|
114
|
+
typename C::T val, typename C::TI ids)
|
115
|
+
{
|
116
|
+
bh_val--; /* Use 1-based indexing for easier node->child translation */
|
117
|
+
bh_ids--;
|
118
|
+
size_t i = 1, i1, i2;
|
119
|
+
while (1) {
|
120
|
+
i1 = i << 1;
|
121
|
+
i2 = i1 + 1;
|
122
|
+
if (i1 > k)
|
123
|
+
break;
|
124
|
+
if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) {
|
125
|
+
if (C::cmp(val, bh_val[i1]))
|
126
|
+
break;
|
127
|
+
bh_val[i] = bh_val[i1];
|
128
|
+
bh_ids[i] = bh_ids[i1];
|
129
|
+
i = i1;
|
130
|
+
}
|
131
|
+
else {
|
132
|
+
if (C::cmp(val, bh_val[i2]))
|
133
|
+
break;
|
134
|
+
bh_val[i] = bh_val[i2];
|
135
|
+
bh_ids[i] = bh_ids[i2];
|
136
|
+
i = i2;
|
137
|
+
}
|
138
|
+
}
|
139
|
+
bh_val[i] = val;
|
140
|
+
bh_ids[i] = ids;
|
141
|
+
}
|
142
|
+
|
143
|
+
|
144
|
+
|
145
145
|
/* Partial instantiation for heaps with TI = int64_t */
|
146
146
|
|
147
147
|
template <typename T> inline
|
@@ -158,6 +158,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
|
158
158
|
}
|
159
159
|
|
160
160
|
|
161
|
+
template <typename T> inline
|
162
|
+
void minheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
163
|
+
{
|
164
|
+
heap_replace_top<CMin<T, int64_t> > (k, bh_val, bh_ids, val, ids);
|
165
|
+
}
|
166
|
+
|
167
|
+
|
161
168
|
template <typename T> inline
|
162
169
|
void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
|
163
170
|
{
|
@@ -172,6 +179,12 @@ void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
|
172
179
|
}
|
173
180
|
|
174
181
|
|
182
|
+
template <typename T> inline
|
183
|
+
void maxheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
184
|
+
{
|
185
|
+
heap_replace_top<CMax<T, int64_t> > (k, bh_val, bh_ids, val, ids);
|
186
|
+
}
|
187
|
+
|
175
188
|
|
176
189
|
/*******************************************************************
|
177
190
|
* Heap initialization
|
@@ -249,15 +262,13 @@ void heap_addn (size_t k,
|
|
249
262
|
if (ids)
|
250
263
|
for (i = 0; i < n; i++) {
|
251
264
|
if (C::cmp (bh_val[0], x[i])) {
|
252
|
-
|
253
|
-
heap_push<C> (k, bh_val, bh_ids, x[i], ids[i]);
|
265
|
+
heap_replace_top<C> (k, bh_val, bh_ids, x[i], ids[i]);
|
254
266
|
}
|
255
267
|
}
|
256
268
|
else
|
257
269
|
for (i = 0; i < n; i++) {
|
258
270
|
if (C::cmp (bh_val[0], x[i])) {
|
259
|
-
|
260
|
-
heap_push<C> (k, bh_val, bh_ids, x[i], i);
|
271
|
+
heap_replace_top<C> (k, bh_val, bh_ids, x[i], i);
|
261
272
|
}
|
262
273
|
}
|
263
274
|
}
|
@@ -19,6 +19,7 @@
|
|
19
19
|
|
20
20
|
#include <faiss/impl/AuxIndexStructures.h>
|
21
21
|
#include <faiss/impl/FaissAssert.h>
|
22
|
+
#include <faiss/impl/ResultHandler.h>
|
22
23
|
|
23
24
|
|
24
25
|
|
@@ -36,14 +37,6 @@ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
|
36
37
|
FINTEGER *lda, const float *b, FINTEGER *
|
37
38
|
ldb, float *beta, float *c, FINTEGER *ldc);
|
38
39
|
|
39
|
-
/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
|
40
|
-
|
41
|
-
int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
|
42
|
-
float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
|
43
|
-
|
44
|
-
int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
|
45
|
-
const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
|
46
|
-
float *beta, float *y, FINTEGER *incy);
|
47
40
|
|
48
41
|
}
|
49
42
|
|
@@ -58,34 +51,6 @@ namespace faiss {
|
|
58
51
|
|
59
52
|
|
60
53
|
|
61
|
-
/* Compute the inner product between a vector x and
|
62
|
-
a set of ny vectors y.
|
63
|
-
These functions are not intended to replace BLAS matrix-matrix, as they
|
64
|
-
would be significantly less efficient in this case. */
|
65
|
-
void fvec_inner_products_ny (float * ip,
|
66
|
-
const float * x,
|
67
|
-
const float * y,
|
68
|
-
size_t d, size_t ny)
|
69
|
-
{
|
70
|
-
// Not sure which one is fastest
|
71
|
-
#if 0
|
72
|
-
{
|
73
|
-
FINTEGER di = d;
|
74
|
-
FINTEGER nyi = ny;
|
75
|
-
float one = 1.0, zero = 0.0;
|
76
|
-
FINTEGER onei = 1;
|
77
|
-
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
78
|
-
}
|
79
|
-
#endif
|
80
|
-
for (size_t i = 0; i < ny; i++) {
|
81
|
-
ip[i] = fvec_inner_product (x, y, d);
|
82
|
-
y += d;
|
83
|
-
}
|
84
|
-
}
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
54
|
|
90
55
|
/* Compute the L2 norm of a set of nx vectors */
|
91
56
|
void fvec_norms_L2 (float * __restrict nr,
|
@@ -142,109 +107,112 @@ void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
|
142
107
|
* KNN functions
|
143
108
|
***************************************************************************/
|
144
109
|
|
110
|
+
namespace {
|
111
|
+
|
145
112
|
|
146
113
|
|
147
114
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
115
|
+
template<class ResultHandler>
|
116
|
+
void exhaustive_inner_product_seq (
|
117
|
+
const float * x,
|
118
|
+
const float * y,
|
119
|
+
size_t d, size_t nx, size_t ny,
|
120
|
+
ResultHandler &res)
|
152
121
|
{
|
153
|
-
size_t k = res->k;
|
154
122
|
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
155
123
|
|
156
124
|
check_period *= omp_get_max_threads();
|
157
125
|
|
126
|
+
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
127
|
+
|
158
128
|
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
159
129
|
size_t i1 = std::min(i0 + check_period, nx);
|
160
130
|
|
161
|
-
#pragma omp parallel
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
minheap_heapify (k, simi, idxi);
|
131
|
+
#pragma omp parallel
|
132
|
+
{
|
133
|
+
SingleResultHandler resi(res);
|
134
|
+
#pragma omp for
|
135
|
+
for (int64_t i = i0; i < i1; i++) {
|
136
|
+
const float * x_i = x + i * d;
|
137
|
+
const float * y_j = y;
|
170
138
|
|
171
|
-
|
172
|
-
float ip = fvec_inner_product (x_i, y_j, d);
|
139
|
+
resi.begin(i);
|
173
140
|
|
174
|
-
|
175
|
-
|
176
|
-
|
141
|
+
for (size_t j = 0; j < ny; j++) {
|
142
|
+
float ip = fvec_inner_product (x_i, y_j, d);
|
143
|
+
resi.add_result(ip, j);
|
144
|
+
y_j += d;
|
177
145
|
}
|
178
|
-
|
146
|
+
resi.end();
|
179
147
|
}
|
180
|
-
minheap_reorder (k, simi, idxi);
|
181
148
|
}
|
182
149
|
InterruptCallback::check ();
|
183
150
|
}
|
184
151
|
|
185
152
|
}
|
186
153
|
|
187
|
-
|
154
|
+
template<class ResultHandler>
|
155
|
+
void exhaustive_L2sqr_seq (
|
188
156
|
const float * x,
|
189
157
|
const float * y,
|
190
158
|
size_t d, size_t nx, size_t ny,
|
191
|
-
|
159
|
+
ResultHandler & res)
|
192
160
|
{
|
193
|
-
size_t k = res->k;
|
194
161
|
|
195
162
|
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
196
163
|
check_period *= omp_get_max_threads();
|
164
|
+
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
197
165
|
|
198
166
|
for (size_t i0 = 0; i0 < nx; i0 += check_period) {
|
199
167
|
size_t i1 = std::min(i0 + check_period, nx);
|
200
168
|
|
201
|
-
#pragma omp parallel
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
if (disij < simi[0]) {
|
214
|
-
maxheap_pop (k, simi, idxi);
|
215
|
-
maxheap_push (k, simi, idxi, disij, j);
|
169
|
+
#pragma omp parallel
|
170
|
+
{
|
171
|
+
SingleResultHandler resi(res);
|
172
|
+
#pragma omp for
|
173
|
+
for (int64_t i = i0; i < i1; i++) {
|
174
|
+
const float * x_i = x + i * d;
|
175
|
+
const float * y_j = y;
|
176
|
+
resi.begin(i);
|
177
|
+
for (size_t j = 0; j < ny; j++) {
|
178
|
+
float disij = fvec_L2sqr (x_i, y_j, d);
|
179
|
+
resi.add_result(disij, j);
|
180
|
+
y_j += d;
|
216
181
|
}
|
217
|
-
|
182
|
+
resi.end();
|
218
183
|
}
|
219
|
-
maxheap_reorder (k, simi, idxi);
|
220
184
|
}
|
221
185
|
InterruptCallback::check ();
|
222
186
|
}
|
223
187
|
|
224
|
-
}
|
188
|
+
};
|
189
|
+
|
190
|
+
|
191
|
+
|
225
192
|
|
226
193
|
|
227
194
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
228
|
-
|
195
|
+
template<class ResultHandler>
|
196
|
+
void exhaustive_inner_product_blas (
|
229
197
|
const float * x,
|
230
198
|
const float * y,
|
231
199
|
size_t d, size_t nx, size_t ny,
|
232
|
-
|
200
|
+
ResultHandler & res)
|
233
201
|
{
|
234
|
-
res->heapify ();
|
235
|
-
|
236
202
|
// BLAS does not like empty matrices
|
237
203
|
if (nx == 0 || ny == 0) return;
|
238
204
|
|
239
205
|
/* block sizes */
|
240
|
-
const size_t bs_x =
|
241
|
-
|
206
|
+
const size_t bs_x = distance_compute_blas_query_bs;
|
207
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
242
208
|
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
243
209
|
|
244
210
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
245
211
|
size_t i1 = i0 + bs_x;
|
246
212
|
if(i1 > nx) i1 = nx;
|
247
213
|
|
214
|
+
res.begin_multiple(i0, i1);
|
215
|
+
|
248
216
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
249
217
|
size_t j1 = j0 + bs_y;
|
250
218
|
if (j1 > ny) j1 = ny;
|
@@ -258,46 +226,54 @@ static void knn_inner_product_blas (
|
|
258
226
|
ip_block.get(), &nyi);
|
259
227
|
}
|
260
228
|
|
261
|
-
|
262
|
-
|
229
|
+
res.add_results(j0, j1, ip_block.get());
|
230
|
+
|
263
231
|
}
|
232
|
+
res.end_multiple();
|
264
233
|
InterruptCallback::check ();
|
234
|
+
|
265
235
|
}
|
266
|
-
res->reorder ();
|
267
236
|
}
|
268
237
|
|
238
|
+
|
239
|
+
|
240
|
+
|
269
241
|
// distance correction is an operator that can be applied to transform
|
270
242
|
// the distances
|
271
|
-
template<class
|
272
|
-
|
243
|
+
template<class ResultHandler>
|
244
|
+
void exhaustive_L2sqr_blas (
|
245
|
+
const float * x,
|
273
246
|
const float * y,
|
274
247
|
size_t d, size_t nx, size_t ny,
|
275
|
-
|
276
|
-
const
|
248
|
+
ResultHandler & res,
|
249
|
+
const float *y_norms = nullptr)
|
277
250
|
{
|
278
|
-
res->heapify ();
|
279
|
-
|
280
251
|
// BLAS does not like empty matrices
|
281
252
|
if (nx == 0 || ny == 0) return;
|
282
253
|
|
283
|
-
size_t k = res->k;
|
284
|
-
|
285
254
|
/* block sizes */
|
286
|
-
const size_t bs_x =
|
255
|
+
const size_t bs_x = distance_compute_blas_query_bs;
|
256
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
287
257
|
// const size_t bs_x = 16, bs_y = 16;
|
288
|
-
float
|
289
|
-
float
|
290
|
-
float
|
291
|
-
ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
|
258
|
+
std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
|
259
|
+
std::unique_ptr<float []> x_norms(new float[nx]);
|
260
|
+
std::unique_ptr<float []> del2;
|
292
261
|
|
293
|
-
fvec_norms_L2sqr (x_norms, x, d, nx);
|
294
|
-
fvec_norms_L2sqr (y_norms, y, d, ny);
|
262
|
+
fvec_norms_L2sqr (x_norms.get(), x, d, nx);
|
295
263
|
|
264
|
+
if (!y_norms) {
|
265
|
+
float *y_norms2 = new float[ny];
|
266
|
+
del2.reset(y_norms2);
|
267
|
+
fvec_norms_L2sqr (y_norms2, y, d, ny);
|
268
|
+
y_norms = y_norms2;
|
269
|
+
}
|
296
270
|
|
297
271
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
298
272
|
size_t i1 = i0 + bs_x;
|
299
273
|
if(i1 > nx) i1 = nx;
|
300
274
|
|
275
|
+
res.begin_multiple(i0, i1);
|
276
|
+
|
301
277
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
302
278
|
size_t j1 = j0 + bs_y;
|
303
279
|
if (j1 > ny) j1 = ny;
|
@@ -308,42 +284,34 @@ static void knn_L2sqr_blas (const float * x,
|
|
308
284
|
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
309
285
|
y + j0 * d, &di,
|
310
286
|
x + i0 * d, &di, &zero,
|
311
|
-
ip_block, &nyi);
|
287
|
+
ip_block.get(), &nyi);
|
312
288
|
}
|
313
289
|
|
314
|
-
/* collect minima */
|
315
|
-
#pragma omp parallel for
|
316
290
|
for (int64_t i = i0; i < i1; i++) {
|
317
|
-
float *
|
318
|
-
int64_t * __restrict idxi = res->get_ids (i);
|
319
|
-
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
291
|
+
float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
320
292
|
|
321
293
|
for (size_t j = j0; j < j1; j++) {
|
322
|
-
float ip = *ip_line
|
294
|
+
float ip = *ip_line;
|
323
295
|
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
324
296
|
|
325
297
|
// negative values can occur for identical vectors
|
326
298
|
// due to roundoff errors
|
327
299
|
if (dis < 0) dis = 0;
|
328
300
|
|
329
|
-
|
330
|
-
|
331
|
-
if (dis < simi[0]) {
|
332
|
-
maxheap_pop (k, simi, idxi);
|
333
|
-
maxheap_push (k, simi, idxi, dis, j);
|
334
|
-
}
|
301
|
+
*ip_line = dis;
|
302
|
+
ip_line++;
|
335
303
|
}
|
336
304
|
}
|
305
|
+
res.add_results(j0, j1, ip_block.get());
|
337
306
|
}
|
307
|
+
res.end_multiple();
|
338
308
|
InterruptCallback::check ();
|
339
309
|
}
|
340
|
-
res->reorder ();
|
341
|
-
|
342
310
|
}
|
343
311
|
|
344
312
|
|
345
313
|
|
346
|
-
|
314
|
+
} // anonymous namespace
|
347
315
|
|
348
316
|
|
349
317
|
|
@@ -354,58 +322,103 @@ static void knn_L2sqr_blas (const float * x,
|
|
354
322
|
*******************************************************/
|
355
323
|
|
356
324
|
int distance_compute_blas_threshold = 20;
|
325
|
+
int distance_compute_blas_query_bs = 4096;
|
326
|
+
int distance_compute_blas_database_bs = 1024;
|
327
|
+
int distance_compute_min_k_reservoir = 100;
|
357
328
|
|
358
329
|
void knn_inner_product (const float * x,
|
359
330
|
const float * y,
|
360
331
|
size_t d, size_t nx, size_t ny,
|
361
|
-
float_minheap_array_t *
|
332
|
+
float_minheap_array_t * ha)
|
362
333
|
{
|
363
|
-
if (
|
364
|
-
|
334
|
+
if (ha->k < distance_compute_min_k_reservoir) {
|
335
|
+
HeapResultHandler<CMin<float, int64_t>> res(
|
336
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
337
|
+
if (nx < distance_compute_blas_threshold) {
|
338
|
+
exhaustive_inner_product_seq (x, y, d, nx, ny, res);
|
339
|
+
} else {
|
340
|
+
exhaustive_inner_product_blas (x, y, d, nx, ny, res);
|
341
|
+
}
|
365
342
|
} else {
|
366
|
-
|
343
|
+
ReservoirResultHandler<CMin<float, int64_t>> res(
|
344
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
345
|
+
if (nx < distance_compute_blas_threshold) {
|
346
|
+
exhaustive_inner_product_seq (x, y, d, nx, ny, res);
|
347
|
+
} else {
|
348
|
+
exhaustive_inner_product_blas (x, y, d, nx, ny, res);
|
349
|
+
}
|
367
350
|
}
|
368
351
|
}
|
369
352
|
|
370
353
|
|
371
354
|
|
372
|
-
|
373
|
-
|
374
|
-
|
355
|
+
|
356
|
+
void knn_L2sqr (
|
357
|
+
const float * x,
|
358
|
+
const float * y,
|
359
|
+
size_t d, size_t nx, size_t ny,
|
360
|
+
float_maxheap_array_t * ha,
|
361
|
+
const float *y_norm2
|
362
|
+
) {
|
363
|
+
|
364
|
+
if (ha->k < distance_compute_min_k_reservoir) {
|
365
|
+
HeapResultHandler<CMax<float, int64_t>> res(
|
366
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
367
|
+
|
368
|
+
if (nx < distance_compute_blas_threshold) {
|
369
|
+
exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
|
370
|
+
} else {
|
371
|
+
exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
|
372
|
+
}
|
373
|
+
} else {
|
374
|
+
ReservoirResultHandler<CMax<float, int64_t>> res(
|
375
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
376
|
+
if (nx < distance_compute_blas_threshold) {
|
377
|
+
exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
|
378
|
+
} else {
|
379
|
+
exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
|
380
|
+
}
|
375
381
|
}
|
376
|
-
}
|
382
|
+
}
|
377
383
|
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
384
|
+
|
385
|
+
/***************************************************************************
|
386
|
+
* Range search
|
387
|
+
***************************************************************************/
|
388
|
+
|
389
|
+
|
390
|
+
|
391
|
+
|
392
|
+
void range_search_L2sqr (
|
393
|
+
const float * x,
|
394
|
+
const float * y,
|
395
|
+
size_t d, size_t nx, size_t ny,
|
396
|
+
float radius,
|
397
|
+
RangeSearchResult *res)
|
382
398
|
{
|
399
|
+
RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
|
383
400
|
if (nx < distance_compute_blas_threshold) {
|
384
|
-
|
401
|
+
exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
|
385
402
|
} else {
|
386
|
-
|
387
|
-
knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
|
403
|
+
exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
|
388
404
|
}
|
389
405
|
}
|
390
406
|
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
void knn_L2sqr_base_shift (
|
399
|
-
const float * x,
|
400
|
-
const float * y,
|
401
|
-
size_t d, size_t nx, size_t ny,
|
402
|
-
float_maxheap_array_t * res,
|
403
|
-
const float *base_shift)
|
407
|
+
void range_search_inner_product (
|
408
|
+
const float * x,
|
409
|
+
const float * y,
|
410
|
+
size_t d, size_t nx, size_t ny,
|
411
|
+
float radius,
|
412
|
+
RangeSearchResult *res)
|
404
413
|
{
|
405
|
-
BaseShiftDistanceCorrection corr = {base_shift};
|
406
|
-
knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
|
407
|
-
}
|
408
414
|
|
415
|
+
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
|
416
|
+
if (nx < distance_compute_blas_threshold) {
|
417
|
+
exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
|
418
|
+
} else {
|
419
|
+
exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
|
420
|
+
}
|
421
|
+
}
|
409
422
|
|
410
423
|
|
411
424
|
/***************************************************************************
|
@@ -509,8 +522,7 @@ void knn_inner_products_by_idx (const float * x,
|
|
509
522
|
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
|
510
523
|
|
511
524
|
if (ip > simi[0]) {
|
512
|
-
|
513
|
-
minheap_push (k, simi, idxi, ip, idsi[j]);
|
525
|
+
minheap_replace_top (k, simi, idxi, ip, idsi[j]);
|
514
526
|
}
|
515
527
|
}
|
516
528
|
minheap_reorder (k, simi, idxi);
|
@@ -537,8 +549,7 @@ void knn_L2sqr_by_idx (const float * x,
|
|
537
549
|
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
|
538
550
|
|
539
551
|
if (disij < simi[0]) {
|
540
|
-
|
541
|
-
maxheap_push (k, simi, idxi, disij, idsi[j]);
|
552
|
+
maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
|
542
553
|
}
|
543
554
|
}
|
544
555
|
maxheap_reorder (res->k, simi, idxi);
|
@@ -550,172 +561,6 @@ void knn_L2sqr_by_idx (const float * x,
|
|
550
561
|
|
551
562
|
|
552
563
|
|
553
|
-
/***************************************************************************
|
554
|
-
* Range search
|
555
|
-
***************************************************************************/
|
556
|
-
|
557
|
-
/** Find the nearest neighbors for nx queries in a set of ny vectors
|
558
|
-
* compute_l2 = compute pairwise squared L2 distance rather than inner prod
|
559
|
-
*/
|
560
|
-
template <bool compute_l2>
|
561
|
-
static void range_search_blas (
|
562
|
-
const float * x,
|
563
|
-
const float * y,
|
564
|
-
size_t d, size_t nx, size_t ny,
|
565
|
-
float radius,
|
566
|
-
RangeSearchResult *result)
|
567
|
-
{
|
568
|
-
|
569
|
-
// BLAS does not like empty matrices
|
570
|
-
if (nx == 0 || ny == 0) return;
|
571
|
-
|
572
|
-
/* block sizes */
|
573
|
-
const size_t bs_x = 4096, bs_y = 1024;
|
574
|
-
// const size_t bs_x = 16, bs_y = 16;
|
575
|
-
float *ip_block = new float[bs_x * bs_y];
|
576
|
-
ScopeDeleter<float> del0(ip_block);
|
577
|
-
|
578
|
-
float *x_norms = nullptr, *y_norms = nullptr;
|
579
|
-
ScopeDeleter<float> del1, del2;
|
580
|
-
if (compute_l2) {
|
581
|
-
x_norms = new float[nx];
|
582
|
-
del1.set (x_norms);
|
583
|
-
fvec_norms_L2sqr (x_norms, x, d, nx);
|
584
|
-
|
585
|
-
y_norms = new float[ny];
|
586
|
-
del2.set (y_norms);
|
587
|
-
fvec_norms_L2sqr (y_norms, y, d, ny);
|
588
|
-
}
|
589
|
-
|
590
|
-
std::vector <RangeSearchPartialResult *> partial_results;
|
591
|
-
|
592
|
-
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
593
|
-
size_t j1 = j0 + bs_y;
|
594
|
-
if (j1 > ny) j1 = ny;
|
595
|
-
RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
|
596
|
-
partial_results.push_back (pres);
|
597
|
-
|
598
|
-
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
599
|
-
size_t i1 = i0 + bs_x;
|
600
|
-
if(i1 > nx) i1 = nx;
|
601
|
-
|
602
|
-
/* compute the actual dot products */
|
603
|
-
{
|
604
|
-
float one = 1, zero = 0;
|
605
|
-
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
606
|
-
sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
|
607
|
-
y + j0 * d, &di,
|
608
|
-
x + i0 * d, &di, &zero,
|
609
|
-
ip_block, &nyi);
|
610
|
-
}
|
611
|
-
|
612
|
-
|
613
|
-
for (size_t i = i0; i < i1; i++) {
|
614
|
-
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
|
615
|
-
|
616
|
-
RangeQueryResult & qres = pres->new_result (i);
|
617
|
-
|
618
|
-
for (size_t j = j0; j < j1; j++) {
|
619
|
-
float ip = *ip_line++;
|
620
|
-
if (compute_l2) {
|
621
|
-
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
622
|
-
if (dis < radius) {
|
623
|
-
qres.add (dis, j);
|
624
|
-
}
|
625
|
-
} else {
|
626
|
-
if (ip > radius) {
|
627
|
-
qres.add (ip, j);
|
628
|
-
}
|
629
|
-
}
|
630
|
-
}
|
631
|
-
}
|
632
|
-
}
|
633
|
-
InterruptCallback::check ();
|
634
|
-
}
|
635
|
-
|
636
|
-
RangeSearchPartialResult::merge (partial_results);
|
637
|
-
}
|
638
|
-
|
639
|
-
|
640
|
-
template <bool compute_l2>
|
641
|
-
static void range_search_sse (const float * x,
|
642
|
-
const float * y,
|
643
|
-
size_t d, size_t nx, size_t ny,
|
644
|
-
float radius,
|
645
|
-
RangeSearchResult *res)
|
646
|
-
{
|
647
|
-
|
648
|
-
#pragma omp parallel
|
649
|
-
{
|
650
|
-
RangeSearchPartialResult pres (res);
|
651
|
-
|
652
|
-
#pragma omp for
|
653
|
-
for (int64_t i = 0; i < nx; i++) {
|
654
|
-
const float * x_ = x + i * d;
|
655
|
-
const float * y_ = y;
|
656
|
-
size_t j;
|
657
|
-
|
658
|
-
RangeQueryResult & qres = pres.new_result (i);
|
659
|
-
|
660
|
-
for (j = 0; j < ny; j++) {
|
661
|
-
if (compute_l2) {
|
662
|
-
float disij = fvec_L2sqr (x_, y_, d);
|
663
|
-
if (disij < radius) {
|
664
|
-
qres.add (disij, j);
|
665
|
-
}
|
666
|
-
} else {
|
667
|
-
float ip = fvec_inner_product (x_, y_, d);
|
668
|
-
if (ip > radius) {
|
669
|
-
qres.add (ip, j);
|
670
|
-
}
|
671
|
-
}
|
672
|
-
y_ += d;
|
673
|
-
}
|
674
|
-
|
675
|
-
}
|
676
|
-
pres.finalize ();
|
677
|
-
}
|
678
|
-
|
679
|
-
// check just at the end because the use case is typically just
|
680
|
-
// when the nb of queries is low.
|
681
|
-
InterruptCallback::check();
|
682
|
-
}
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
void range_search_L2sqr (
|
689
|
-
const float * x,
|
690
|
-
const float * y,
|
691
|
-
size_t d, size_t nx, size_t ny,
|
692
|
-
float radius,
|
693
|
-
RangeSearchResult *res)
|
694
|
-
{
|
695
|
-
|
696
|
-
if (nx < distance_compute_blas_threshold) {
|
697
|
-
range_search_sse<true> (x, y, d, nx, ny, radius, res);
|
698
|
-
} else {
|
699
|
-
range_search_blas<true> (x, y, d, nx, ny, radius, res);
|
700
|
-
}
|
701
|
-
}
|
702
|
-
|
703
|
-
void range_search_inner_product (
|
704
|
-
const float * x,
|
705
|
-
const float * y,
|
706
|
-
size_t d, size_t nx, size_t ny,
|
707
|
-
float radius,
|
708
|
-
RangeSearchResult *res)
|
709
|
-
{
|
710
|
-
|
711
|
-
if (nx < distance_compute_blas_threshold) {
|
712
|
-
range_search_sse<false> (x, y, d, nx, ny, radius, res);
|
713
|
-
} else {
|
714
|
-
range_search_blas<false> (x, y, d, nx, ny, radius, res);
|
715
|
-
}
|
716
|
-
}
|
717
|
-
|
718
|
-
|
719
564
|
void pairwise_L2sqr (int64_t d,
|
720
565
|
int64_t nq, const float *xq,
|
721
566
|
int64_t nb, const float *xb,
|