faiss 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,141 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #pragma once
10
+
11
+ #include <cstdint>
12
+ #include <cstdlib>
13
+ #include <cassert>
14
+ #include <cstring>
15
+
16
+ #include <algorithm>
17
+
18
+ #include <faiss/impl/platform_macros.h>
19
+
20
+ namespace faiss {
21
+
22
+ template<int A=32>
23
+ inline bool is_aligned_pointer(const void* x)
24
+ {
25
+ size_t xi = (size_t)x;
26
+ return xi % A == 0;
27
+ }
28
+
29
+ // class that manages suitably aligned arrays for SIMD
30
+ // T should be a POD type. The default alignment is 32 for AVX
31
+ template<class T, int A=32>
32
+ struct AlignedTableTightAlloc {
33
+ T * ptr;
34
+ size_t numel;
35
+
36
+ AlignedTableTightAlloc(): ptr(nullptr), numel(0)
37
+ { }
38
+
39
+ explicit AlignedTableTightAlloc(size_t n): ptr(nullptr), numel(0)
40
+ { resize(n); }
41
+
42
+ size_t itemsize() const {return sizeof(T); }
43
+
44
+ void resize(size_t n) {
45
+ if (numel == n) {
46
+ return;
47
+ }
48
+ T * new_ptr;
49
+ if (n > 0) {
50
+ int ret = posix_memalign((void**)&new_ptr, A, n * sizeof(T));
51
+ if (ret != 0) {
52
+ throw std::bad_alloc();
53
+ }
54
+ if (numel > 0) {
55
+ memcpy(new_ptr, ptr, sizeof(T) * std::min(numel, n));
56
+ }
57
+ } else {
58
+ new_ptr = nullptr;
59
+ }
60
+ numel = n;
61
+ posix_memalign_free(ptr);
62
+ ptr = new_ptr;
63
+ }
64
+
65
+ void clear() {memset(ptr, 0, nbytes()); }
66
+ size_t size() const {return numel; }
67
+ size_t nbytes() const {return numel * sizeof(T); }
68
+
69
+ T * get() {return ptr; }
70
+ const T * get() const {return ptr; }
71
+ T * data() {return ptr; }
72
+ const T * data() const {return ptr; }
73
+ T & operator [] (size_t i) {return ptr[i]; }
74
+ T operator [] (size_t i) const {return ptr[i]; }
75
+
76
+ ~AlignedTableTightAlloc() {posix_memalign_free(ptr); }
77
+
78
+ AlignedTableTightAlloc<T, A> & operator =
79
+ (const AlignedTableTightAlloc<T, A> & other) {
80
+ resize(other.numel);
81
+ memcpy(ptr, other.ptr, sizeof(T) * numel);
82
+ return *this;
83
+ }
84
+
85
+ AlignedTableTightAlloc(const AlignedTableTightAlloc<T, A> & other) {
86
+ *this = other;
87
+ }
88
+
89
+ };
90
+
91
+ // same as AlignedTableTightAlloc, but with geometric re-allocation
92
+ template<class T, int A=32>
93
+ struct AlignedTable {
94
+ AlignedTableTightAlloc<T, A> tab;
95
+ size_t numel = 0;
96
+
97
+ static size_t round_capacity(size_t n) {
98
+ if (n == 0) {
99
+ return 0;
100
+ }
101
+ if (n < 8 * A) {
102
+ return 8 * A;
103
+ }
104
+ size_t capacity = 8 * A;
105
+ while (capacity < n) {
106
+ capacity *= 2;
107
+ }
108
+ return capacity;
109
+ }
110
+
111
+ AlignedTable() {}
112
+
113
+ explicit AlignedTable(size_t n):
114
+ tab(round_capacity(n)),
115
+ numel(n)
116
+ { }
117
+
118
+ size_t itemsize() const {return sizeof(T); }
119
+
120
+ void resize(size_t n) {
121
+ tab.resize(round_capacity(n));
122
+ numel = n;
123
+ }
124
+
125
+ void clear() { tab.clear(); }
126
+ size_t size() const {return numel; }
127
+ size_t nbytes() const {return numel * sizeof(T); }
128
+
129
+ T * get() {return tab.get(); }
130
+ const T * get() const {return tab.get(); }
131
+ T * data() {return tab.get(); }
132
+ const T * data() const {return tab.get(); }
133
+ T & operator [] (size_t i) {return tab.ptr[i]; }
134
+ T operator [] (size_t i) const {return tab.ptr[i]; }
135
+
136
+ // assign and copy constructor should work as expected
137
+
138
+ };
139
+
140
+
141
+ } // namespace faiss
@@ -46,8 +46,7 @@ void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
46
46
  for (size_t j = 0; j < nj; j++) {
47
47
  T ip = ip_line [j];
48
48
  if (C::cmp(simi[0], ip)) {
49
- heap_pop<C> (k, simi, idxi);
50
- heap_push<C> (k, simi, idxi, ip, j + j0);
49
+ heap_replace_top<C> (k, simi, idxi, ip, j + j0);
51
50
  }
52
51
  }
53
52
  }
@@ -74,8 +73,7 @@ void HeapArray<C>::addn_with_ids (
74
73
  for (size_t j = 0; j < nj; j++) {
75
74
  T ip = ip_line [j];
76
75
  if (C::cmp(simi[0], ip)) {
77
- heap_pop<C> (k, simi, idxi);
78
- heap_push<C> (k, simi, idxi, ip, id_line [j]);
76
+ heap_replace_top<C> (k, simi, idxi, ip, id_line [j]);
79
77
  }
80
78
  }
81
79
  }
@@ -5,16 +5,18 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
8
 
10
9
  /*
11
- * C++ support for heaps. The set of functions is tailored for
12
- * efficient similarity search.
10
+ * C++ support for heaps. The set of functions is tailored for efficient
11
+ * similarity search.
13
12
  *
14
- * There is no specific object for a heap, and the functions that
15
- * operate on a signle heap are inlined, because heaps are often
16
- * small. More complex functions are implemented in Heaps.cpp
13
+ * There is no specific object for a heap, and the functions that operate on a
14
+ * single heap are inlined, because heaps are often small. More complex
15
+ * functions are implemented in Heap.cpp
17
16
  *
17
+ * All heap functions rely on a C template class that defines the type of the
18
+ * keys and values and their ordering (increasing with CMax and decreasing with
19
+ * CMin). The C types are defined in ordered_key_value.h
18
20
  */
19
21
 
20
22
 
@@ -31,51 +33,12 @@
31
33
 
32
34
  #include <limits>
33
35
 
36
+ #include <faiss/utils/ordered_key_value.h>
34
37
 
35
38
  namespace faiss {
36
39
 
37
- /*******************************************************************
38
- * C object: uniform handling of min and max heap
39
- *******************************************************************/
40
-
41
- /** The C object gives the type T of the values in the heap, the type
42
- * of the keys, TI and the comparison that is done: > for the minheap
43
- * and < for the maxheap. The neutral value will always be dropped in
44
- * favor of any other value in the heap.
45
- */
46
-
47
- template <typename T_, typename TI_>
48
- struct CMax;
49
-
50
- // traits of minheaps = heaps where the minimum value is stored on top
51
- // useful to find the *max* values of an array
52
- template <typename T_, typename TI_>
53
- struct CMin {
54
- typedef T_ T;
55
- typedef TI_ TI;
56
- typedef CMax<T_, TI_> Crev;
57
- inline static bool cmp (T a, T b) {
58
- return a < b;
59
- }
60
- inline static T neutral () {
61
- return std::numeric_limits<T>::lowest();
62
- }
63
- };
64
40
 
65
41
 
66
- template <typename T_, typename TI_>
67
- struct CMax {
68
- typedef T_ T;
69
- typedef TI_ TI;
70
- typedef CMin<T_, TI_> Crev;
71
- inline static bool cmp (T a, T b) {
72
- return a > b;
73
- }
74
- inline static T neutral () {
75
- return std::numeric_limits<T>::max();
76
- }
77
- };
78
-
79
42
 
80
43
  /*******************************************************************
81
44
  * Basic heap ops: push and pop
@@ -142,6 +105,43 @@ void heap_push (size_t k,
142
105
 
143
106
 
144
107
 
108
+ /** Replace the top element from the heap defined by bh_val[0..k-1] and
109
+ * bh_ids[0..k-1].
110
+ */
111
+ template <class C> inline
112
+ void heap_replace_top (size_t k,
113
+ typename C::T * bh_val, typename C::TI * bh_ids,
114
+ typename C::T val, typename C::TI ids)
115
+ {
116
+ bh_val--; /* Use 1-based indexing for easier node->child translation */
117
+ bh_ids--;
118
+ size_t i = 1, i1, i2;
119
+ while (1) {
120
+ i1 = i << 1;
121
+ i2 = i1 + 1;
122
+ if (i1 > k)
123
+ break;
124
+ if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) {
125
+ if (C::cmp(val, bh_val[i1]))
126
+ break;
127
+ bh_val[i] = bh_val[i1];
128
+ bh_ids[i] = bh_ids[i1];
129
+ i = i1;
130
+ }
131
+ else {
132
+ if (C::cmp(val, bh_val[i2]))
133
+ break;
134
+ bh_val[i] = bh_val[i2];
135
+ bh_ids[i] = bh_ids[i2];
136
+ i = i2;
137
+ }
138
+ }
139
+ bh_val[i] = val;
140
+ bh_ids[i] = ids;
141
+ }
142
+
143
+
144
+
145
145
  /* Partial instantiation for heaps with TI = int64_t */
146
146
 
147
147
  template <typename T> inline
@@ -158,6 +158,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
158
158
  }
159
159
 
160
160
 
161
+ template <typename T> inline
162
+ void minheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
163
+ {
164
+ heap_replace_top<CMin<T, int64_t> > (k, bh_val, bh_ids, val, ids);
165
+ }
166
+
167
+
161
168
  template <typename T> inline
162
169
  void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
163
170
  {
@@ -172,6 +179,12 @@ void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
172
179
  }
173
180
 
174
181
 
182
+ template <typename T> inline
183
+ void maxheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
184
+ {
185
+ heap_replace_top<CMax<T, int64_t> > (k, bh_val, bh_ids, val, ids);
186
+ }
187
+
175
188
 
176
189
  /*******************************************************************
177
190
  * Heap initialization
@@ -249,15 +262,13 @@ void heap_addn (size_t k,
249
262
  if (ids)
250
263
  for (i = 0; i < n; i++) {
251
264
  if (C::cmp (bh_val[0], x[i])) {
252
- heap_pop<C> (k, bh_val, bh_ids);
253
- heap_push<C> (k, bh_val, bh_ids, x[i], ids[i]);
265
+ heap_replace_top<C> (k, bh_val, bh_ids, x[i], ids[i]);
254
266
  }
255
267
  }
256
268
  else
257
269
  for (i = 0; i < n; i++) {
258
270
  if (C::cmp (bh_val[0], x[i])) {
259
- heap_pop<C> (k, bh_val, bh_ids);
260
- heap_push<C> (k, bh_val, bh_ids, x[i], i);
271
+ heap_replace_top<C> (k, bh_val, bh_ids, x[i], i);
261
272
  }
262
273
  }
263
274
  }
@@ -19,6 +19,7 @@
19
19
 
20
20
  #include <faiss/impl/AuxIndexStructures.h>
21
21
  #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/impl/ResultHandler.h>
22
23
 
23
24
 
24
25
 
@@ -36,14 +37,6 @@ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
36
37
  FINTEGER *lda, const float *b, FINTEGER *
37
38
  ldb, float *beta, float *c, FINTEGER *ldc);
38
39
 
39
- /* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
40
-
41
- int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
42
- float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
43
-
44
- int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
45
- const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
46
- float *beta, float *y, FINTEGER *incy);
47
40
 
48
41
  }
49
42
 
@@ -58,34 +51,6 @@ namespace faiss {
58
51
 
59
52
 
60
53
 
61
- /* Compute the inner product between a vector x and
62
- a set of ny vectors y.
63
- These functions are not intended to replace BLAS matrix-matrix, as they
64
- would be significantly less efficient in this case. */
65
- void fvec_inner_products_ny (float * ip,
66
- const float * x,
67
- const float * y,
68
- size_t d, size_t ny)
69
- {
70
- // Not sure which one is fastest
71
- #if 0
72
- {
73
- FINTEGER di = d;
74
- FINTEGER nyi = ny;
75
- float one = 1.0, zero = 0.0;
76
- FINTEGER onei = 1;
77
- sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
78
- }
79
- #endif
80
- for (size_t i = 0; i < ny; i++) {
81
- ip[i] = fvec_inner_product (x, y, d);
82
- y += d;
83
- }
84
- }
85
-
86
-
87
-
88
-
89
54
 
90
55
  /* Compute the L2 norm of a set of nx vectors */
91
56
  void fvec_norms_L2 (float * __restrict nr,
@@ -142,109 +107,112 @@ void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
142
107
  * KNN functions
143
108
  ***************************************************************************/
144
109
 
110
+ namespace {
111
+
145
112
 
146
113
 
147
114
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
148
- static void knn_inner_product_sse (const float * x,
149
- const float * y,
150
- size_t d, size_t nx, size_t ny,
151
- float_minheap_array_t * res)
115
+ template<class ResultHandler>
116
+ void exhaustive_inner_product_seq (
117
+ const float * x,
118
+ const float * y,
119
+ size_t d, size_t nx, size_t ny,
120
+ ResultHandler &res)
152
121
  {
153
- size_t k = res->k;
154
122
  size_t check_period = InterruptCallback::get_period_hint (ny * d);
155
123
 
156
124
  check_period *= omp_get_max_threads();
157
125
 
126
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
127
+
158
128
  for (size_t i0 = 0; i0 < nx; i0 += check_period) {
159
129
  size_t i1 = std::min(i0 + check_period, nx);
160
130
 
161
- #pragma omp parallel for
162
- for (int64_t i = i0; i < i1; i++) {
163
- const float * x_i = x + i * d;
164
- const float * y_j = y;
165
-
166
- float * __restrict simi = res->get_val(i);
167
- int64_t * __restrict idxi = res->get_ids (i);
168
-
169
- minheap_heapify (k, simi, idxi);
131
+ #pragma omp parallel
132
+ {
133
+ SingleResultHandler resi(res);
134
+ #pragma omp for
135
+ for (int64_t i = i0; i < i1; i++) {
136
+ const float * x_i = x + i * d;
137
+ const float * y_j = y;
170
138
 
171
- for (size_t j = 0; j < ny; j++) {
172
- float ip = fvec_inner_product (x_i, y_j, d);
139
+ resi.begin(i);
173
140
 
174
- if (ip > simi[0]) {
175
- minheap_pop (k, simi, idxi);
176
- minheap_push (k, simi, idxi, ip, j);
141
+ for (size_t j = 0; j < ny; j++) {
142
+ float ip = fvec_inner_product (x_i, y_j, d);
143
+ resi.add_result(ip, j);
144
+ y_j += d;
177
145
  }
178
- y_j += d;
146
+ resi.end();
179
147
  }
180
- minheap_reorder (k, simi, idxi);
181
148
  }
182
149
  InterruptCallback::check ();
183
150
  }
184
151
 
185
152
  }
186
153
 
187
- static void knn_L2sqr_sse (
154
+ template<class ResultHandler>
155
+ void exhaustive_L2sqr_seq (
188
156
  const float * x,
189
157
  const float * y,
190
158
  size_t d, size_t nx, size_t ny,
191
- float_maxheap_array_t * res)
159
+ ResultHandler & res)
192
160
  {
193
- size_t k = res->k;
194
161
 
195
162
  size_t check_period = InterruptCallback::get_period_hint (ny * d);
196
163
  check_period *= omp_get_max_threads();
164
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
197
165
 
198
166
  for (size_t i0 = 0; i0 < nx; i0 += check_period) {
199
167
  size_t i1 = std::min(i0 + check_period, nx);
200
168
 
201
- #pragma omp parallel for
202
- for (int64_t i = i0; i < i1; i++) {
203
- const float * x_i = x + i * d;
204
- const float * y_j = y;
205
- size_t j;
206
- float * simi = res->get_val(i);
207
- int64_t * idxi = res->get_ids (i);
208
-
209
- maxheap_heapify (k, simi, idxi);
210
- for (j = 0; j < ny; j++) {
211
- float disij = fvec_L2sqr (x_i, y_j, d);
212
-
213
- if (disij < simi[0]) {
214
- maxheap_pop (k, simi, idxi);
215
- maxheap_push (k, simi, idxi, disij, j);
169
+ #pragma omp parallel
170
+ {
171
+ SingleResultHandler resi(res);
172
+ #pragma omp for
173
+ for (int64_t i = i0; i < i1; i++) {
174
+ const float * x_i = x + i * d;
175
+ const float * y_j = y;
176
+ resi.begin(i);
177
+ for (size_t j = 0; j < ny; j++) {
178
+ float disij = fvec_L2sqr (x_i, y_j, d);
179
+ resi.add_result(disij, j);
180
+ y_j += d;
216
181
  }
217
- y_j += d;
182
+ resi.end();
218
183
  }
219
- maxheap_reorder (k, simi, idxi);
220
184
  }
221
185
  InterruptCallback::check ();
222
186
  }
223
187
 
224
- }
188
+ };
189
+
190
+
191
+
225
192
 
226
193
 
227
194
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
228
- static void knn_inner_product_blas (
195
+ template<class ResultHandler>
196
+ void exhaustive_inner_product_blas (
229
197
  const float * x,
230
198
  const float * y,
231
199
  size_t d, size_t nx, size_t ny,
232
- float_minheap_array_t * res)
200
+ ResultHandler & res)
233
201
  {
234
- res->heapify ();
235
-
236
202
  // BLAS does not like empty matrices
237
203
  if (nx == 0 || ny == 0) return;
238
204
 
239
205
  /* block sizes */
240
- const size_t bs_x = 4096, bs_y = 1024;
241
- // const size_t bs_x = 16, bs_y = 16;
206
+ const size_t bs_x = distance_compute_blas_query_bs;
207
+ const size_t bs_y = distance_compute_blas_database_bs;
242
208
  std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
243
209
 
244
210
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
245
211
  size_t i1 = i0 + bs_x;
246
212
  if(i1 > nx) i1 = nx;
247
213
 
214
+ res.begin_multiple(i0, i1);
215
+
248
216
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
249
217
  size_t j1 = j0 + bs_y;
250
218
  if (j1 > ny) j1 = ny;
@@ -258,46 +226,54 @@ static void knn_inner_product_blas (
258
226
  ip_block.get(), &nyi);
259
227
  }
260
228
 
261
- /* collect maxima */
262
- res->addn (j1 - j0, ip_block.get(), j0, i0, i1 - i0);
229
+ res.add_results(j0, j1, ip_block.get());
230
+
263
231
  }
232
+ res.end_multiple();
264
233
  InterruptCallback::check ();
234
+
265
235
  }
266
- res->reorder ();
267
236
  }
268
237
 
238
+
239
+
240
+
269
241
  // distance correction is an operator that can be applied to transform
270
242
  // the distances
271
- template<class DistanceCorrection>
272
- static void knn_L2sqr_blas (const float * x,
243
+ template<class ResultHandler>
244
+ void exhaustive_L2sqr_blas (
245
+ const float * x,
273
246
  const float * y,
274
247
  size_t d, size_t nx, size_t ny,
275
- float_maxheap_array_t * res,
276
- const DistanceCorrection &corr)
248
+ ResultHandler & res,
249
+ const float *y_norms = nullptr)
277
250
  {
278
- res->heapify ();
279
-
280
251
  // BLAS does not like empty matrices
281
252
  if (nx == 0 || ny == 0) return;
282
253
 
283
- size_t k = res->k;
284
-
285
254
  /* block sizes */
286
- const size_t bs_x = 4096, bs_y = 1024;
255
+ const size_t bs_x = distance_compute_blas_query_bs;
256
+ const size_t bs_y = distance_compute_blas_database_bs;
287
257
  // const size_t bs_x = 16, bs_y = 16;
288
- float *ip_block = new float[bs_x * bs_y];
289
- float *x_norms = new float[nx];
290
- float *y_norms = new float[ny];
291
- ScopeDeleter<float> del1(ip_block), del3(x_norms), del2(y_norms);
258
+ std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
259
+ std::unique_ptr<float []> x_norms(new float[nx]);
260
+ std::unique_ptr<float []> del2;
292
261
 
293
- fvec_norms_L2sqr (x_norms, x, d, nx);
294
- fvec_norms_L2sqr (y_norms, y, d, ny);
262
+ fvec_norms_L2sqr (x_norms.get(), x, d, nx);
295
263
 
264
+ if (!y_norms) {
265
+ float *y_norms2 = new float[ny];
266
+ del2.reset(y_norms2);
267
+ fvec_norms_L2sqr (y_norms2, y, d, ny);
268
+ y_norms = y_norms2;
269
+ }
296
270
 
297
271
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
298
272
  size_t i1 = i0 + bs_x;
299
273
  if(i1 > nx) i1 = nx;
300
274
 
275
+ res.begin_multiple(i0, i1);
276
+
301
277
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
302
278
  size_t j1 = j0 + bs_y;
303
279
  if (j1 > ny) j1 = ny;
@@ -308,42 +284,34 @@ static void knn_L2sqr_blas (const float * x,
308
284
  sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
309
285
  y + j0 * d, &di,
310
286
  x + i0 * d, &di, &zero,
311
- ip_block, &nyi);
287
+ ip_block.get(), &nyi);
312
288
  }
313
289
 
314
- /* collect minima */
315
- #pragma omp parallel for
316
290
  for (int64_t i = i0; i < i1; i++) {
317
- float * __restrict simi = res->get_val(i);
318
- int64_t * __restrict idxi = res->get_ids (i);
319
- const float *ip_line = ip_block + (i - i0) * (j1 - j0);
291
+ float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
320
292
 
321
293
  for (size_t j = j0; j < j1; j++) {
322
- float ip = *ip_line++;
294
+ float ip = *ip_line;
323
295
  float dis = x_norms[i] + y_norms[j] - 2 * ip;
324
296
 
325
297
  // negative values can occur for identical vectors
326
298
  // due to roundoff errors
327
299
  if (dis < 0) dis = 0;
328
300
 
329
- dis = corr (dis, i, j);
330
-
331
- if (dis < simi[0]) {
332
- maxheap_pop (k, simi, idxi);
333
- maxheap_push (k, simi, idxi, dis, j);
334
- }
301
+ *ip_line = dis;
302
+ ip_line++;
335
303
  }
336
304
  }
305
+ res.add_results(j0, j1, ip_block.get());
337
306
  }
307
+ res.end_multiple();
338
308
  InterruptCallback::check ();
339
309
  }
340
- res->reorder ();
341
-
342
310
  }
343
311
 
344
312
 
345
313
 
346
-
314
+ } // anonymous namespace
347
315
 
348
316
 
349
317
 
@@ -354,58 +322,103 @@ static void knn_L2sqr_blas (const float * x,
354
322
  *******************************************************/
355
323
 
356
324
  int distance_compute_blas_threshold = 20;
325
+ int distance_compute_blas_query_bs = 4096;
326
+ int distance_compute_blas_database_bs = 1024;
327
+ int distance_compute_min_k_reservoir = 100;
357
328
 
358
329
  void knn_inner_product (const float * x,
359
330
  const float * y,
360
331
  size_t d, size_t nx, size_t ny,
361
- float_minheap_array_t * res)
332
+ float_minheap_array_t * ha)
362
333
  {
363
- if (nx < distance_compute_blas_threshold) {
364
- knn_inner_product_sse (x, y, d, nx, ny, res);
334
+ if (ha->k < distance_compute_min_k_reservoir) {
335
+ HeapResultHandler<CMin<float, int64_t>> res(
336
+ ha->nh, ha->val, ha->ids, ha->k);
337
+ if (nx < distance_compute_blas_threshold) {
338
+ exhaustive_inner_product_seq (x, y, d, nx, ny, res);
339
+ } else {
340
+ exhaustive_inner_product_blas (x, y, d, nx, ny, res);
341
+ }
365
342
  } else {
366
- knn_inner_product_blas (x, y, d, nx, ny, res);
343
+ ReservoirResultHandler<CMin<float, int64_t>> res(
344
+ ha->nh, ha->val, ha->ids, ha->k);
345
+ if (nx < distance_compute_blas_threshold) {
346
+ exhaustive_inner_product_seq (x, y, d, nx, ny, res);
347
+ } else {
348
+ exhaustive_inner_product_blas (x, y, d, nx, ny, res);
349
+ }
367
350
  }
368
351
  }
369
352
 
370
353
 
371
354
 
372
- struct NopDistanceCorrection {
373
- float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
374
- return dis;
355
+
356
+ void knn_L2sqr (
357
+ const float * x,
358
+ const float * y,
359
+ size_t d, size_t nx, size_t ny,
360
+ float_maxheap_array_t * ha,
361
+ const float *y_norm2
362
+ ) {
363
+
364
+ if (ha->k < distance_compute_min_k_reservoir) {
365
+ HeapResultHandler<CMax<float, int64_t>> res(
366
+ ha->nh, ha->val, ha->ids, ha->k);
367
+
368
+ if (nx < distance_compute_blas_threshold) {
369
+ exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
370
+ } else {
371
+ exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
372
+ }
373
+ } else {
374
+ ReservoirResultHandler<CMax<float, int64_t>> res(
375
+ ha->nh, ha->val, ha->ids, ha->k);
376
+ if (nx < distance_compute_blas_threshold) {
377
+ exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
378
+ } else {
379
+ exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
380
+ }
375
381
  }
376
- };
382
+ }
377
383
 
378
- void knn_L2sqr (const float * x,
379
- const float * y,
380
- size_t d, size_t nx, size_t ny,
381
- float_maxheap_array_t * res)
384
+
385
+ /***************************************************************************
386
+ * Range search
387
+ ***************************************************************************/
388
+
389
+
390
+
391
+
392
+ void range_search_L2sqr (
393
+ const float * x,
394
+ const float * y,
395
+ size_t d, size_t nx, size_t ny,
396
+ float radius,
397
+ RangeSearchResult *res)
382
398
  {
399
+ RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
383
400
  if (nx < distance_compute_blas_threshold) {
384
- knn_L2sqr_sse (x, y, d, nx, ny, res);
401
+ exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
385
402
  } else {
386
- NopDistanceCorrection nop;
387
- knn_L2sqr_blas (x, y, d, nx, ny, res, nop);
403
+ exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
388
404
  }
389
405
  }
390
406
 
391
- struct BaseShiftDistanceCorrection {
392
- const float *base_shift;
393
- float operator()(float dis, size_t /*qno*/, size_t bno) const {
394
- return dis - base_shift[bno];
395
- }
396
- };
397
-
398
- void knn_L2sqr_base_shift (
399
- const float * x,
400
- const float * y,
401
- size_t d, size_t nx, size_t ny,
402
- float_maxheap_array_t * res,
403
- const float *base_shift)
407
+ void range_search_inner_product (
408
+ const float * x,
409
+ const float * y,
410
+ size_t d, size_t nx, size_t ny,
411
+ float radius,
412
+ RangeSearchResult *res)
404
413
  {
405
- BaseShiftDistanceCorrection corr = {base_shift};
406
- knn_L2sqr_blas (x, y, d, nx, ny, res, corr);
407
- }
408
414
 
415
+ RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
416
+ if (nx < distance_compute_blas_threshold) {
417
+ exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
418
+ } else {
419
+ exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
420
+ }
421
+ }
409
422
 
410
423
 
411
424
  /***************************************************************************
@@ -509,8 +522,7 @@ void knn_inner_products_by_idx (const float * x,
509
522
  float ip = fvec_inner_product (x_, y + d * idsi[j], d);
510
523
 
511
524
  if (ip > simi[0]) {
512
- minheap_pop (k, simi, idxi);
513
- minheap_push (k, simi, idxi, ip, idsi[j]);
525
+ minheap_replace_top (k, simi, idxi, ip, idsi[j]);
514
526
  }
515
527
  }
516
528
  minheap_reorder (k, simi, idxi);
@@ -537,8 +549,7 @@ void knn_L2sqr_by_idx (const float * x,
537
549
  float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
538
550
 
539
551
  if (disij < simi[0]) {
540
- maxheap_pop (k, simi, idxi);
541
- maxheap_push (k, simi, idxi, disij, idsi[j]);
552
+ maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
542
553
  }
543
554
  }
544
555
  maxheap_reorder (res->k, simi, idxi);
@@ -550,172 +561,6 @@ void knn_L2sqr_by_idx (const float * x,
550
561
 
551
562
 
552
563
 
553
- /***************************************************************************
554
- * Range search
555
- ***************************************************************************/
556
-
557
- /** Find the nearest neighbors for nx queries in a set of ny vectors
558
- * compute_l2 = compute pairwise squared L2 distance rather than inner prod
559
- */
560
- template <bool compute_l2>
561
- static void range_search_blas (
562
- const float * x,
563
- const float * y,
564
- size_t d, size_t nx, size_t ny,
565
- float radius,
566
- RangeSearchResult *result)
567
- {
568
-
569
- // BLAS does not like empty matrices
570
- if (nx == 0 || ny == 0) return;
571
-
572
- /* block sizes */
573
- const size_t bs_x = 4096, bs_y = 1024;
574
- // const size_t bs_x = 16, bs_y = 16;
575
- float *ip_block = new float[bs_x * bs_y];
576
- ScopeDeleter<float> del0(ip_block);
577
-
578
- float *x_norms = nullptr, *y_norms = nullptr;
579
- ScopeDeleter<float> del1, del2;
580
- if (compute_l2) {
581
- x_norms = new float[nx];
582
- del1.set (x_norms);
583
- fvec_norms_L2sqr (x_norms, x, d, nx);
584
-
585
- y_norms = new float[ny];
586
- del2.set (y_norms);
587
- fvec_norms_L2sqr (y_norms, y, d, ny);
588
- }
589
-
590
- std::vector <RangeSearchPartialResult *> partial_results;
591
-
592
- for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
593
- size_t j1 = j0 + bs_y;
594
- if (j1 > ny) j1 = ny;
595
- RangeSearchPartialResult * pres = new RangeSearchPartialResult (result);
596
- partial_results.push_back (pres);
597
-
598
- for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
599
- size_t i1 = i0 + bs_x;
600
- if(i1 > nx) i1 = nx;
601
-
602
- /* compute the actual dot products */
603
- {
604
- float one = 1, zero = 0;
605
- FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
606
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
607
- y + j0 * d, &di,
608
- x + i0 * d, &di, &zero,
609
- ip_block, &nyi);
610
- }
611
-
612
-
613
- for (size_t i = i0; i < i1; i++) {
614
- const float *ip_line = ip_block + (i - i0) * (j1 - j0);
615
-
616
- RangeQueryResult & qres = pres->new_result (i);
617
-
618
- for (size_t j = j0; j < j1; j++) {
619
- float ip = *ip_line++;
620
- if (compute_l2) {
621
- float dis = x_norms[i] + y_norms[j] - 2 * ip;
622
- if (dis < radius) {
623
- qres.add (dis, j);
624
- }
625
- } else {
626
- if (ip > radius) {
627
- qres.add (ip, j);
628
- }
629
- }
630
- }
631
- }
632
- }
633
- InterruptCallback::check ();
634
- }
635
-
636
- RangeSearchPartialResult::merge (partial_results);
637
- }
638
-
639
-
640
- template <bool compute_l2>
641
- static void range_search_sse (const float * x,
642
- const float * y,
643
- size_t d, size_t nx, size_t ny,
644
- float radius,
645
- RangeSearchResult *res)
646
- {
647
-
648
- #pragma omp parallel
649
- {
650
- RangeSearchPartialResult pres (res);
651
-
652
- #pragma omp for
653
- for (int64_t i = 0; i < nx; i++) {
654
- const float * x_ = x + i * d;
655
- const float * y_ = y;
656
- size_t j;
657
-
658
- RangeQueryResult & qres = pres.new_result (i);
659
-
660
- for (j = 0; j < ny; j++) {
661
- if (compute_l2) {
662
- float disij = fvec_L2sqr (x_, y_, d);
663
- if (disij < radius) {
664
- qres.add (disij, j);
665
- }
666
- } else {
667
- float ip = fvec_inner_product (x_, y_, d);
668
- if (ip > radius) {
669
- qres.add (ip, j);
670
- }
671
- }
672
- y_ += d;
673
- }
674
-
675
- }
676
- pres.finalize ();
677
- }
678
-
679
- // check just at the end because the use case is typically just
680
- // when the nb of queries is low.
681
- InterruptCallback::check();
682
- }
683
-
684
-
685
-
686
-
687
-
688
- void range_search_L2sqr (
689
- const float * x,
690
- const float * y,
691
- size_t d, size_t nx, size_t ny,
692
- float radius,
693
- RangeSearchResult *res)
694
- {
695
-
696
- if (nx < distance_compute_blas_threshold) {
697
- range_search_sse<true> (x, y, d, nx, ny, radius, res);
698
- } else {
699
- range_search_blas<true> (x, y, d, nx, ny, radius, res);
700
- }
701
- }
702
-
703
- void range_search_inner_product (
704
- const float * x,
705
- const float * y,
706
- size_t d, size_t nx, size_t ny,
707
- float radius,
708
- RangeSearchResult *res)
709
- {
710
-
711
- if (nx < distance_compute_blas_threshold) {
712
- range_search_sse<false> (x, y, d, nx, ny, radius, res);
713
- } else {
714
- range_search_blas<false> (x, y, d, nx, ny, radius, res);
715
- }
716
- }
717
-
718
-
719
564
  void pairwise_L2sqr (int64_t d,
720
565
  int64_t nq, const float *xq,
721
566
  int64_t nb, const float *xb,