faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,141 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #pragma once
10
+
11
+ #include <cstdint>
12
+ #include <cstdlib>
13
+ #include <cassert>
14
+ #include <cstring>
15
+
16
+ #include <algorithm>
17
+
18
+ #include <faiss/impl/platform_macros.h>
19
+
20
+ namespace faiss {
21
+
22
+ template<int A=32>
23
+ inline bool is_aligned_pointer(const void* x)
24
+ {
25
+ size_t xi = (size_t)x;
26
+ return xi % A == 0;
27
+ }
28
+
29
+ // class that manages suitably aligned arrays for SIMD
30
+ // T should be a POD type. The default alignment is 32 for AVX
31
+ template<class T, int A=32>
32
+ struct AlignedTableTightAlloc {
33
+ T * ptr;
34
+ size_t numel;
35
+
36
+ AlignedTableTightAlloc(): ptr(nullptr), numel(0)
37
+ { }
38
+
39
+ explicit AlignedTableTightAlloc(size_t n): ptr(nullptr), numel(0)
40
+ { resize(n); }
41
+
42
+ size_t itemsize() const {return sizeof(T); }
43
+
44
+ void resize(size_t n) {
45
+ if (numel == n) {
46
+ return;
47
+ }
48
+ T * new_ptr;
49
+ if (n > 0) {
50
+ int ret = posix_memalign((void**)&new_ptr, A, n * sizeof(T));
51
+ if (ret != 0) {
52
+ throw std::bad_alloc();
53
+ }
54
+ if (numel > 0) {
55
+ memcpy(new_ptr, ptr, sizeof(T) * std::min(numel, n));
56
+ }
57
+ } else {
58
+ new_ptr = nullptr;
59
+ }
60
+ numel = n;
61
+ posix_memalign_free(ptr);
62
+ ptr = new_ptr;
63
+ }
64
+
65
+ void clear() {memset(ptr, 0, nbytes()); }
66
+ size_t size() const {return numel; }
67
+ size_t nbytes() const {return numel * sizeof(T); }
68
+
69
+ T * get() {return ptr; }
70
+ const T * get() const {return ptr; }
71
+ T * data() {return ptr; }
72
+ const T * data() const {return ptr; }
73
+ T & operator [] (size_t i) {return ptr[i]; }
74
+ T operator [] (size_t i) const {return ptr[i]; }
75
+
76
+ ~AlignedTableTightAlloc() {posix_memalign_free(ptr); }
77
+
78
+ AlignedTableTightAlloc<T, A> & operator =
79
+ (const AlignedTableTightAlloc<T, A> & other) {
80
+ resize(other.numel);
81
+ memcpy(ptr, other.ptr, sizeof(T) * numel);
82
+ return *this;
83
+ }
84
+
85
+ AlignedTableTightAlloc(const AlignedTableTightAlloc<T, A> & other) {
86
+ *this = other;
87
+ }
88
+
89
+ };
90
+
91
+ // same as AlignedTableTightAlloc, but with geometric re-allocation
92
+ template<class T, int A=32>
93
+ struct AlignedTable {
94
+ AlignedTableTightAlloc<T, A> tab;
95
+ size_t numel = 0;
96
+
97
+ static size_t round_capacity(size_t n) {
98
+ if (n == 0) {
99
+ return 0;
100
+ }
101
+ if (n < 8 * A) {
102
+ return 8 * A;
103
+ }
104
+ size_t capacity = 8 * A;
105
+ while (capacity < n) {
106
+ capacity *= 2;
107
+ }
108
+ return capacity;
109
+ }
110
+
111
+ AlignedTable() {}
112
+
113
+ explicit AlignedTable(size_t n):
114
+ tab(round_capacity(n)),
115
+ numel(n)
116
+ { }
117
+
118
+ size_t itemsize() const {return sizeof(T); }
119
+
120
+ void resize(size_t n) {
121
+ tab.resize(round_capacity(n));
122
+ numel = n;
123
+ }
124
+
125
+ void clear() { tab.clear(); }
126
+ size_t size() const {return numel; }
127
+ size_t nbytes() const {return numel * sizeof(T); }
128
+
129
+ T * get() {return tab.get(); }
130
+ const T * get() const {return tab.get(); }
131
+ T * data() {return tab.get(); }
132
+ const T * data() const {return tab.get(); }
133
+ T & operator [] (size_t i) {return tab.ptr[i]; }
134
+ T operator [] (size_t i) const {return tab.ptr[i]; }
135
+
136
+ // assign and copy constructor should work as expected
137
+
138
+ };
139
+
140
+
141
+ } // namespace faiss
@@ -46,8 +46,7 @@ void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
46
46
  for (size_t j = 0; j < nj; j++) {
47
47
  T ip = ip_line [j];
48
48
  if (C::cmp(simi[0], ip)) {
49
- heap_pop<C> (k, simi, idxi);
50
- heap_push<C> (k, simi, idxi, ip, j + j0);
49
+ heap_replace_top<C> (k, simi, idxi, ip, j + j0);
51
50
  }
52
51
  }
53
52
  }
@@ -74,8 +73,7 @@ void HeapArray<C>::addn_with_ids (
74
73
  for (size_t j = 0; j < nj; j++) {
75
74
  T ip = ip_line [j];
76
75
  if (C::cmp(simi[0], ip)) {
77
- heap_pop<C> (k, simi, idxi);
78
- heap_push<C> (k, simi, idxi, ip, id_line [j]);
76
+ heap_replace_top<C> (k, simi, idxi, ip, id_line [j]);
79
77
  }
80
78
  }
81
79
  }
@@ -5,16 +5,18 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
8
 
10
9
  /*
11
- * C++ support for heaps. The set of functions is tailored for
12
- * efficient similarity search.
10
+ * C++ support for heaps. The set of functions is tailored for efficient
11
+ * similarity search.
13
12
  *
14
- * There is no specific object for a heap, and the functions that
15
- * operate on a signle heap are inlined, because heaps are often
16
- * small. More complex functions are implemented in Heaps.cpp
13
+ * There is no specific object for a heap, and the functions that operate on a
14
+ * single heap are inlined, because heaps are often small. More complex
15
+ * functions are implemented in Heap.cpp
17
16
  *
17
+ * All heap functions rely on a C template class that defines the type of the
18
+ * keys and values and their ordering (increasing with CMax and decreasing with
19
+ * CMin). The C types are defined in ordered_key_value.h
18
20
  */
19
21
 
20
22
 
@@ -31,51 +33,12 @@
31
33
 
32
34
  #include <limits>
33
35
 
36
+ #include <faiss/utils/ordered_key_value.h>
34
37
 
35
38
  namespace faiss {
36
39
 
37
- /*******************************************************************
38
- * C object: uniform handling of min and max heap
39
- *******************************************************************/
40
-
41
- /** The C object gives the type T of the values in the heap, the type
42
- * of the keys, TI and the comparison that is done: > for the minheap
43
- * and < for the maxheap. The neutral value will always be dropped in
44
- * favor of any other value in the heap.
45
- */
46
-
47
- template <typename T_, typename TI_>
48
- struct CMax;
49
-
50
- // traits of minheaps = heaps where the minimum value is stored on top
51
- // useful to find the *max* values of an array
52
- template <typename T_, typename TI_>
53
- struct CMin {
54
- typedef T_ T;
55
- typedef TI_ TI;
56
- typedef CMax<T_, TI_> Crev;
57
- inline static bool cmp (T a, T b) {
58
- return a < b;
59
- }
60
- inline static T neutral () {
61
- return std::numeric_limits<T>::lowest();
62
- }
63
- };
64
40
 
65
41
 
66
- template <typename T_, typename TI_>
67
- struct CMax {
68
- typedef T_ T;
69
- typedef TI_ TI;
70
- typedef CMin<T_, TI_> Crev;
71
- inline static bool cmp (T a, T b) {
72
- return a > b;
73
- }
74
- inline static T neutral () {
75
- return std::numeric_limits<T>::max();
76
- }
77
- };
78
-
79
42
 
80
43
  /*******************************************************************
81
44
  * Basic heap ops: push and pop
@@ -142,6 +105,43 @@ void heap_push (size_t k,
142
105
 
143
106
 
144
107
 
108
+ /** Replace the top element from the heap defined by bh_val[0..k-1] and
109
+ * bh_ids[0..k-1].
110
+ */
111
+ template <class C> inline
112
+ void heap_replace_top (size_t k,
113
+ typename C::T * bh_val, typename C::TI * bh_ids,
114
+ typename C::T val, typename C::TI ids)
115
+ {
116
+ bh_val--; /* Use 1-based indexing for easier node->child translation */
117
+ bh_ids--;
118
+ size_t i = 1, i1, i2;
119
+ while (1) {
120
+ i1 = i << 1;
121
+ i2 = i1 + 1;
122
+ if (i1 > k)
123
+ break;
124
+ if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) {
125
+ if (C::cmp(val, bh_val[i1]))
126
+ break;
127
+ bh_val[i] = bh_val[i1];
128
+ bh_ids[i] = bh_ids[i1];
129
+ i = i1;
130
+ }
131
+ else {
132
+ if (C::cmp(val, bh_val[i2]))
133
+ break;
134
+ bh_val[i] = bh_val[i2];
135
+ bh_ids[i] = bh_ids[i2];
136
+ i = i2;
137
+ }
138
+ }
139
+ bh_val[i] = val;
140
+ bh_ids[i] = ids;
141
+ }
142
+
143
+
144
+
145
145
  /* Partial instantiation for heaps with TI = int64_t */
146
146
 
147
147
  template <typename T> inline
@@ -158,6 +158,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
158
158
  }
159
159
 
160
160
 
161
+ template <typename T> inline
162
+ void minheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
163
+ {
164
+ heap_replace_top<CMin<T, int64_t> > (k, bh_val, bh_ids, val, ids);
165
+ }
166
+
167
+
161
168
  template <typename T> inline
162
169
  void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
163
170
  {
@@ -172,6 +179,12 @@ void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
172
179
  }
173
180
 
174
181
 
182
+ template <typename T> inline
183
+ void maxheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
184
+ {
185
+ heap_replace_top<CMax<T, int64_t> > (k, bh_val, bh_ids, val, ids);
186
+ }
187
+
175
188
 
176
189
  /*******************************************************************
177
190
  * Heap initialization
@@ -249,15 +262,13 @@ void heap_addn (size_t k,
249
262
  if (ids)
250
263
  for (i = 0; i < n; i++) {
251
264
  if (C::cmp (bh_val[0], x[i])) {
252
- heap_pop<C> (k, bh_val, bh_ids);
253
- heap_push<C> (k, bh_val, bh_ids, x[i], ids[i]);
265
+ heap_replace_top<C> (k, bh_val, bh_ids, x[i], ids[i]);
254
266
  }
255
267
  }
256
268
  else
257
269
  for (i = 0; i < n; i++) {
258
270
  if (C::cmp (bh_val[0], x[i])) {
259
- heap_pop<C> (k, bh_val, bh_ids);
260
- heap_push<C> (k, bh_val, bh_ids, x[i], i);
271
+ heap_replace_top<C> (k, bh_val, bh_ids, x[i], i);
261
272
  }
262
273
  }
263
274
  }
@@ -19,6 +19,7 @@
19
19
 
20
20
  #include <faiss/impl/AuxIndexStructures.h>
21
21
  #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/impl/ResultHandler.h>
22
23
 
23
24
 
24
25
 
@@ -36,14 +37,6 @@ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
36
37
  FINTEGER *lda, const float *b, FINTEGER *
37
38
  ldb, float *beta, float *c, FINTEGER *ldc);
38
39
 
39
- /* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
40
-
41
- int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
42
- float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
43
-
44
- int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
45
- const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
46
- float *beta, float *y, FINTEGER *incy);
47
40
 
48
41
  }
49
42
 
@@ -58,34 +51,6 @@ namespace faiss {
58
51
 
59
52
 
60
53
 
61
- /* Compute the inner product between a vector x and
62
- a set of ny vectors y.
63
- These functions are not intended to replace BLAS matrix-matrix, as they
64
- would be significantly less efficient in this case. */
65
- void fvec_inner_products_ny (float * ip,
66
- const float * x,
67
- const float * y,
68
- size_t d, size_t ny)
69
- {
70
- // Not sure which one is fastest
71
- #if 0
72
- {
73
- FINTEGER di = d;
74
- FINTEGER nyi = ny;
75
- float one = 1.0, zero = 0.0;
76
- FINTEGER onei = 1;
77
- sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
78
- }
79
- #endif
80
- for (size_t i = 0; i < ny; i++) {
81
- ip[i] = fvec_inner_product (x, y, d);
82
- y += d;
83
- }
84
- }
85
-
86
-
87
-
88
-
89
54
 
90
55
  /* Compute the L2 norm of a set of nx vectors */
91
56
  void fvec_norms_L2 (float * __restrict nr,
@@ -142,109 +107,112 @@ void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
142
107
  * KNN functions
143
108
  ***************************************************************************/
144
109
 
110
+ namespace {
111
+
145
112
 
146
113
 
147
114
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
148
- static void knn_inner_product_sse (const float * x,
149
- const float * y,
150
- size_t d, size_t nx, size_t ny,
151
- float_minheap_array_t * res)
115
+ template<class ResultHandler>
116
+ void exhaustive_inner_product_seq (
117
+ const float * x,
118
+ const float * y,
119
+ size_t d, size_t nx, size_t ny,
120
+ ResultHandler &res)
152
121
  {
153
- size_t k = res->k;
154
122
  size_t check_period = InterruptCallback::get_period_hint (ny * d);
155
123
 
156
124
  check_period *= omp_get_max_threads();
157
125
 
126
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
127
+
158
128
  for (size_t i0 = 0; i0 < nx; i0 += check_period) {
159
129
  size_t i1 = std::min(i0 + check_period, nx);
160
130
 
161
- #pragma omp parallel for
162
- for (int64_t i = i0; i < i1; i++) {
163
- const float * x_i = x + i * d;
164
- const float * y_j = y;
165
-
166
- float * __restrict simi = res->get_val(i);
167
- int64_t * __restrict idxi = res->get_ids (i);
168
-
169
- minheap_heapify (k, simi, idxi);
131
+ #pragma omp parallel
132
+ {
133
+ SingleResultHandler resi(res);
134
+ #pragma omp for
135
+ for (int64_t i = i0; i < i1; i++) {
136
+ const float * x_i = x + i * d;
137
+ const float * y_j = y;
170
138
 
171
- for (size_t j = 0; j < ny; j++) {
172
- float ip = fvec_inner_product (x_i, y_j, d);
139
+ resi.begin(i);
173
140
 
174
- if (ip > simi[0]) {
175
- minheap_pop (k, simi, idxi);
176
- minheap_push (k, simi, idxi, ip, j);
141
+ for (size_t j = 0; j < ny; j++) {
142
+ float ip = fvec_inner_product (x_i, y_j, d);
143
+ resi.add_result(ip, j);
144
+ y_j += d;
177
145
  }
178
- y_j += d;
146
+ resi.end();
179
147
  }
180
- minheap_reorder (k, simi, idxi);
181
148
  }
182
149
  InterruptCallback::check ();
183
150
  }
184
151
 
185
152
  }
186
153
 
187
- static void knn_L2sqr_sse (
154
+ template<class ResultHandler>
155
+ void exhaustive_L2sqr_seq (
188
156
  const float * x,
189
157
  const float * y,
190
158
  size_t d, size_t nx, size_t ny,
191
- float_maxheap_array_t * res)
159
+ ResultHandler & res)
192
160
  {
193
- size_t k = res->k;
194
161
 
195
162
  size_t check_period = InterruptCallback::get_period_hint (ny * d);
196
163
  check_period *= omp_get_max_threads();
164
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
197
165
 
198
166
  for (size_t i0 = 0; i0 < nx; i0 += check_period) {
199
167
  size_t i1 = std::min(i0 + check_period, nx);
200
168
 
201
- #pragma omp parallel for
202
- for (int64_t i = i0; i < i1; i++) {
203
- const float * x_i = x + i * d;
204
- const float * y_j = y;
205
- size_t j;
206
- float * simi = res->get_val(i);
207
- int64_t * idxi = res->get_ids (i);
208
-
209
- maxheap_heapify (k, simi, idxi);
210
- for (j = 0; j < ny; j++) {
211
- float disij = fvec_L2sqr (x_i, y_j, d);
212
-
213
- if (disij < simi[0]) {
214
- maxheap_pop (k, simi, idxi);
215
- maxheap_push (k, simi, idxi, disij, j);
169
+ #pragma omp parallel
170
+ {
171
+ SingleResultHandler resi(res);
172
+ #pragma omp for
173
+ for (int64_t i = i0; i < i1; i++) {
174
+ const float * x_i = x + i * d;
175
+ const float * y_j = y;
176
+ resi.begin(i);
177
+ for (size_t j = 0; j < ny; j++) {
178
+ float disij = fvec_L2sqr (x_i, y_j, d);
179
+ resi.add_result(disij, j);
180
+ y_j += d;
216
181
  }
217
- y_j += d;
182
+ resi.end();
218
183
  }
219
- maxheap_reorder (k, simi, idxi);
220
184
  }
221
185
  InterruptCallback::check ();
222
186
  }
223
187
 
224
- }
188
+ };
189
+
190
+
191
+
225
192
 
226
193
 
227
194
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
228
- static void knn_inner_product_blas (
195
+ template<class ResultHandler>
196
+ void exhaustive_inner_product_blas (
229
197
  const float * x,
230
198
  const float * y,
231
199
  size_t d, size_t nx, size_t ny,
232
- float_minheap_array_t * res)
200
+ ResultHandler & res)
233
201
  {
234
- res->heapify ();
235
-
236
202
  // BLAS does not like empty matrices
237
203
  if (nx == 0 || ny == 0) return;
238
204
 
239
205
  /* block sizes */
240
- const size_t bs_x = 4096, bs_y = 1024;
241
- // const size_t bs_x = 16, bs_y = 16;
206
+ const size_t bs_x = distance_compute_blas_query_bs;
207
+ const size_t bs_y = distance_compute_blas_database_bs;
242
208
  std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
243
209
 
244
210
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
245
211
  size_t i1 = i0 + bs_x;
246
212
  if(i1 > nx) i1 = nx;
247
213
 
214
+ res.begin_multiple(i0, i1);
215
+
248
216
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
249
217
  size_t j1 = j0 + bs_y;
250
218
  if (j1 > ny) j1 = ny;
@@ -258,46 +226,54 @@ static void knn_inner_product_blas (
258
226
  ip_block.get(), &nyi);
259
227
  }
260
228
 
261
- /* collect maxima */
262
- res->addn (j1 - j0, ip_block.get(), j0, i0, i1 - i0);
229
+ res.add_results(j0, j1, ip_block.get());
230
+
263
231
  }
232
+ res.end_multiple();
264
233
  InterruptCallback::check ();
234
+
265
235
  }
266
- res->reorder ();
267
236
  }
268
237
 
238
+
239
+
240
+
269
241
  // distance correction is an operator that can be applied to transform
270
242
  // the distances
271
- template<class DistanceCorrection>
272
- static void knn_L2sqr_blas (const float * x,
243
+ template<class ResultHandler>
244
+ void exhaustive_L2sqr_blas (
245
+ const float * x,
273
246
  const float * y,
274
247
  size_t d, size_t nx, size_t ny,
275
- float_maxheap_array_t * res,
276
- const DistanceCorrection &corr)
248
+ ResultHandler & res,
249
+ const float *y_norms = nullptr)
277
250
  {
278
- res->heapify ();
279
-
280
251
  // BLAS does not like empty matrices
281
252
  if (nx == 0 || ny == 0) return;
282
253
 
283
- size_t k = res->k;
284
-
285
254
  /* block sizes */
286
- const size_t bs_x = 4096, bs_y = 1024;
255
+ const size_t bs_x = distance_compute_blas_query_bs;
256
+ const size_t bs_y = distance_compute_blas_database_bs;
287
257
  // const size_t bs_x = 16, bs_y = 16;
288
- float *ip_block = new float[bs_x * bs_y];
289
- float *x_norms = new float[nx];
290
- float *y_norms = new float[ny];
291
- ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
258
+ std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
259
+ std::unique_ptr<float []> x_norms(new float[nx]);
260
+ std::unique_ptr<float []> del2;
292
261
 
293
- fvec_norms_L2sqr (x_norms, x, d, nx);
294
- fvec_norms_L2sqr (y_norms, y, d, ny);
262
+ fvec_norms_L2sqr (x_norms.get(), x, d, nx);
295
263
 
264
+ if (!y_norms) {
265
+ float *y_norms2 = new float[ny];
266
+ del2.reset(y_norms2);
267
+ fvec_norms_L2sqr (y_norms2, y, d, ny);
268
+ y_norms = y_norms2;
269
+ }
296
270
 
297
271
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
298
272
  size_t i1 = i0 + bs_x;
299
273
  if(i1 > nx) i1 = nx;
300
274
 
275
+ res.begin_multiple(i0, i1);
276
+
301
277
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
302
278
  size_t j1 = j0 + bs_y;
303
279
  if (j1 > ny) j1 = ny;
@@ -308,42 +284,34 @@ static void knn_L2sqr_blas (const float * x,
308
284
  sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
309
285
  y + j0 * d, &di,
310
286
  x + i0 * d, &di, &zero,
311
- ip_block, &nyi);
287
+ ip_block.get(), &nyi);
312
288
  }
313
289
 
314
- /* collect minima */
315
- #pragma omp parallel for
316
290
  for (int64_t i = i0; i < i1; i++) {
317
- float * __restrict simi = res->get_val(i);
318
- int64_t * __restrict idxi = res->get_ids (i);
319
- const float *ip_line = ip_block + (i - i0) * (j1 - j0);
291
+ float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
320
292
 
321
293
  for (size_t j = j0; j < j1; j++) {
322
- float ip = *ip_line++;
294
+ float ip = *ip_line;
323
295
  float dis = x_norms[i] + y_norms[j] - 2 * ip;
324
296
 
325
297
  // negative values can occur for identical vectors
326
298
  // due to roundoff errors
327
299
  if (dis < 0) dis = 0;
328
300
 
329
- dis = corr (dis, i, j);
330
-
331
- if (dis < simi[0]) {
332
- maxheap_pop (k, simi, idxi);
333
- maxheap_push (k, simi, idxi, dis, j);
334
- }
301
+ *ip_line = dis;
302
+ ip_line++;
335
303
  }
336
304
  }
305
+ res.add_results(j0, j1, ip_block.get());
337
306
  }
307
+ res.end_multiple();
338
308
  InterruptCallback::check ();
339
309
  }
340
- res->reorder ();
341
-
342
310
  }
343
311
 
344
312
 
345
313
 
346
-
314
+ } // anonymous namespace
347
315
 
348
316
 
349
317
 
@@ -354,58 +322,103 @@ static void knn_L2sqr_blas (const float * x,
354
322
  *******************************************************/
355
323
 
356
324
  int distance_compute_blas_threshold = 20;
325
+ int distance_compute_blas_query_bs = 4096;
326
+ int distance_compute_blas_database_bs = 1024;
327
+ int distance_compute_min_k_reservoir = 100;
357
328
 
358
329
  void knn_inner_product (const float * x,
359
330
  const float * y,
360
331
  size_t d, size_t nx, size_t ny,
361
- float_minheap_array_t * res)
332
+ float_minheap_array_t * ha)
362
333
  {
363
- if (nx < distance_compute_blas_threshold) {
364
- knn_inner_product_sse (x, y, d, nx, ny, res);
334
+ if (ha->k < distance_compute_min_k_reservoir) {
335
+ HeapResultHandler<CMin<float, int64_t>> res(
336
+ ha->nh, ha->val, ha->ids, ha->k);
337
+ if (nx < distance_compute_blas_threshold) {
338
+ exhaustive_inner_product_seq (x, y, d, nx, ny, res);
339
+ } else {
340
+ exhaustive_inner_product_blas (x, y, d, nx, ny, res);
341
+ }
365
342
  } else {
366
- knn_inner_product_blas (x, y, d, nx, ny, res);
343
+ ReservoirResultHandler<CMin<float, int64_t>> res(
344
+ ha->nh, ha->val, ha->ids, ha->k);
345
+ if (nx < distance_compute_blas_threshold) {
346
+ exhaustive_inner_product_seq (x, y, d, nx, ny, res);
347
+ } else {
348
+ exhaustive_inner_product_blas (x, y, d, nx, ny, res);
349
+ }
367
350
  }
368
351
  }
369
352
 
370
353
 
371
354
 
372
- struct NopDistanceCorrection {
373
- float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
374
- return dis;
355
+
356
+ void knn_L2sqr (
357
+ const float * x,
358
+ const float * y,
359
+ size_t d, size_t nx, size_t ny,
360
+ float_maxheap_array_t * ha,
361
+ const float *y_norm2
362
+ ) {
363
+
364
+ if (ha->k < distance_compute_min_k_reservoir) {
365
+ HeapResultHandler<CMax<float, int64_t>> res(
366
+ ha->nh, ha->val, ha->ids, ha->k);
367
+
368
+ if (nx < distance_compute_blas_threshold) {
369
+ exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
370
+ } else {
371
+ exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
372
+ }
373
+ } else {
374
+ ReservoirResultHandler<CMax<float, int64_t>> res(
375
+ ha->nh, ha->val, ha->ids, ha->k);
376
+ if (nx < distance_compute_blas_threshold) {
377
+ exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
378
+ } else {
379
+ exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
380
+ }
375
381
  }
376
- };
382
+ }
377
383
 
378
- void knn_L2sqr (const float * x,
379
- const float * y,
380
- size_t d, size_t nx, size_t ny,
381
- float_maxheap_array_t * res)
384
+
385
+ /***************************************************************************
386
+ * Range search
387
+ ***************************************************************************/
388
+
389
+
390
+
391
+
392
+ void range_search_L2sqr (
393
+ const float * x,
394
+ const float * y,
395
+ size_t d, size_t nx, size_t ny,
396
+ float radius,
397
+ RangeSearchResult *res)
382
398
  {
399
+ RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
383
400
  if (nx < distance_compute_blas_threshold) {
384
- knn_L2sqr_sse (x, y, d, nx, ny, res);
401
+ exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
385
402
  } else {
386
- NopDistanceCorrection nop;
387
- knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
403
+ exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
388
404
  }
389
405
  }
390
406
 
391
- struct BaseShiftDistanceCorrection {
392
- const float *base_shift;
393
- float operator()(float dis, size_t /*qno*/, size_t bno) const {
394
- return dis - base_shift[bno];
395
- }
396
- };
397
-
398
- void knn_L2sqr_base_shift (
399
- const float * x,
400
- const float * y,
401
- size_t d, size_t nx, size_t ny,
402
- float_maxheap_array_t * res,
403
- const float *base_shift)
407
+ void range_search_inner_product (
408
+ const float * x,
409
+ const float * y,
410
+ size_t d, size_t nx, size_t ny,
411
+ float radius,
412
+ RangeSearchResult *res)
404
413
  {
405
- BaseShiftDistanceCorrection corr = {base_shift};
406
- knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
407
- }
408
414
 
415
+ RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
416
+ if (nx < distance_compute_blas_threshold) {
417
+ exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
418
+ } else {
419
+ exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
420
+ }
421
+ }
409
422
 
410
423
 
411
424
  /***************************************************************************
@@ -509,8 +522,7 @@ void knn_inner_products_by_idx (const float * x,
509
522
  float ip = fvec_inner_product (x_, y + d * idsi[j], d);
510
523
 
511
524
  if (ip > simi[0]) {
512
- minheap_pop (k, simi, idxi);
513
- minheap_push (k, simi, idxi, ip, idsi[j]);
525
+ minheap_replace_top (k, simi, idxi, ip, idsi[j]);
514
526
  }
515
527
  }
516
528
  minheap_reorder (k, simi, idxi);
@@ -537,8 +549,7 @@ void knn_L2sqr_by_idx (const float * x,
537
549
  float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
538
550
 
539
551
  if (disij < simi[0]) {
540
- maxheap_pop (k, simi, idxi);
541
- maxheap_push (k, simi, idxi, disij, idsi[j]);
552
+ maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
542
553
  }
543
554
  }
544
555
  maxheap_reorder (res->k, simi, idxi);
@@ -550,172 +561,6 @@ void knn_L2sqr_by_idx (const float * x,
550
561
 
551
562
 
552
563
 
553
- /***************************************************************************
554
- * Range search
555
- ***************************************************************************/
556
-
557
- /** Find the nearest neighbors for nx queries in a set of ny vectors
558
- * compute_l2 = compute pairwise squared L2 distance rather than inner prod
559
- */
560
- template <bool compute_l2>
561
- static void range_search_blas (
562
- const float * x,
563
- const float * y,
564
- size_t d, size_t nx, size_t ny,
565
- float radius,
566
- RangeSearchResult *result)
567
- {
568
-
569
- // BLAS does not like empty matrices
570
- if (nx == 0 || ny == 0) return;
571
-
572
- /* block sizes */
573
- const size_t bs_x = 4096, bs_y = 1024;
574
- // const size_t bs_x = 16, bs_y = 16;
575
- float *ip_block = new float[bs_x * bs_y];
576
- ScopeDeleter<float> del0(ip_block);
577
-
578
- float *x_norms = nullptr, *y_norms = nullptr;
579
- ScopeDeleter<float> del1, del2;
580
- if (compute_l2) {
581
- x_norms = new float[nx];
582
- del1.set (x_norms);
583
- fvec_norms_L2sqr (x_norms, x, d, nx);
584
-
585
- y_norms = new float[ny];
586
- del2.set (y_norms);
587
- fvec_norms_L2sqr (y_norms, y, d, ny);
588
- }
589
-
590
- std::vector <RangeSearchPartialResult *> partial_results;
591
-
592
- for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
593
- size_t j1 = j0 + bs_y;
594
- if (j1 > ny) j1 = ny;
595
- RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
596
- partial_results.push_back (pres);
597
-
598
- for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
599
- size_t i1 = i0 + bs_x;
600
- if(i1 > nx) i1 = nx;
601
-
602
- /* compute the actual dot products */
603
- {
604
- float one = 1, zero = 0;
605
- FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
606
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
607
- y + j0 * d, &di,
608
- x + i0 * d, &di, &zero,
609
- ip_block, &nyi);
610
- }
611
-
612
-
613
- for (size_t i = i0; i < i1; i++) {
614
- const float *ip_line = ip_block + (i - i0) * (j1 - j0);
615
-
616
- RangeQueryResult & qres = pres->new_result (i);
617
-
618
- for (size_t j = j0; j < j1; j++) {
619
- float ip = *ip_line++;
620
- if (compute_l2) {
621
- float dis = x_norms[i] + y_norms[j] - 2 * ip;
622
- if (dis < radius) {
623
- qres.add (dis, j);
624
- }
625
- } else {
626
- if (ip > radius) {
627
- qres.add (ip, j);
628
- }
629
- }
630
- }
631
- }
632
- }
633
- InterruptCallback::check ();
634
- }
635
-
636
- RangeSearchPartialResult::merge (partial_results);
637
- }
638
-
639
-
640
- template <bool compute_l2>
641
- static void range_search_sse (const float * x,
642
- const float * y,
643
- size_t d, size_t nx, size_t ny,
644
- float radius,
645
- RangeSearchResult *res)
646
- {
647
-
648
- #pragma omp parallel
649
- {
650
- RangeSearchPartialResult pres (res);
651
-
652
- #pragma omp for
653
- for (int64_t i = 0; i < nx; i++) {
654
- const float * x_ = x + i * d;
655
- const float * y_ = y;
656
- size_t j;
657
-
658
- RangeQueryResult & qres = pres.new_result (i);
659
-
660
- for (j = 0; j < ny; j++) {
661
- if (compute_l2) {
662
- float disij = fvec_L2sqr (x_, y_, d);
663
- if (disij < radius) {
664
- qres.add (disij, j);
665
- }
666
- } else {
667
- float ip = fvec_inner_product (x_, y_, d);
668
- if (ip > radius) {
669
- qres.add (ip, j);
670
- }
671
- }
672
- y_ += d;
673
- }
674
-
675
- }
676
- pres.finalize ();
677
- }
678
-
679
- // check just at the end because the use case is typically just
680
- // when the nb of queries is low.
681
- InterruptCallback::check();
682
- }
683
-
684
-
685
-
686
-
687
-
688
- void range_search_L2sqr (
689
- const float * x,
690
- const float * y,
691
- size_t d, size_t nx, size_t ny,
692
- float radius,
693
- RangeSearchResult *res)
694
- {
695
-
696
- if (nx < distance_compute_blas_threshold) {
697
- range_search_sse<true> (x, y, d, nx, ny, radius, res);
698
- } else {
699
- range_search_blas<true> (x, y, d, nx, ny, radius, res);
700
- }
701
- }
702
-
703
- void range_search_inner_product (
704
- const float * x,
705
- const float * y,
706
- size_t d, size_t nx, size_t ny,
707
- float radius,
708
- RangeSearchResult *res)
709
- {
710
-
711
- if (nx < distance_compute_blas_threshold) {
712
- range_search_sse<false> (x, y, d, nx, ny, radius, res);
713
- } else {
714
- range_search_blas<false> (x, y, d, nx, ny, radius, res);
715
- }
716
- }
717
-
718
-
719
564
  void pairwise_L2sqr (int64_t d,
720
565
  int64_t nq, const float *xq,
721
566
  int64_t nb, const float *xb,