faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,117 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ /** In this file are the implementations of extra metrics beyond L2
9
+ * and inner product */
10
+
11
+ #include <faiss/utils/distances.h>
12
+ #include <type_traits>
13
+
14
+ namespace faiss {
15
+
16
+ template <MetricType mt>
17
+ struct VectorDistance {
18
+ size_t d;
19
+ float metric_arg;
20
+
21
+ inline float operator()(const float* x, const float* y) const;
22
+
23
+ // heap template to use for this type of metric
24
+ using C = typename std::conditional<
25
+ mt == METRIC_INNER_PRODUCT,
26
+ CMin<float, int64_t>,
27
+ CMax<float, int64_t>>::type;
28
+ };
29
+
30
+ template <>
31
+ inline float VectorDistance<METRIC_L2>::operator()(
32
+ const float* x,
33
+ const float* y) const {
34
+ return fvec_L2sqr(x, y, d);
35
+ }
36
+
37
+ template <>
38
+ inline float VectorDistance<METRIC_INNER_PRODUCT>::operator()(
39
+ const float* x,
40
+ const float* y) const {
41
+ return fvec_inner_product(x, y, d);
42
+ }
43
+
44
+ template <>
45
+ inline float VectorDistance<METRIC_L1>::operator()(
46
+ const float* x,
47
+ const float* y) const {
48
+ return fvec_L1(x, y, d);
49
+ }
50
+
51
+ template <>
52
+ inline float VectorDistance<METRIC_Linf>::operator()(
53
+ const float* x,
54
+ const float* y) const {
55
+ return fvec_Linf(x, y, d);
56
+ /*
57
+ float vmax = 0;
58
+ for (size_t i = 0; i < d; i++) {
59
+ float diff = fabs (x[i] - y[i]);
60
+ if (diff > vmax) vmax = diff;
61
+ }
62
+ return vmax;*/
63
+ }
64
+
65
+ template <>
66
+ inline float VectorDistance<METRIC_Lp>::operator()(
67
+ const float* x,
68
+ const float* y) const {
69
+ float accu = 0;
70
+ for (size_t i = 0; i < d; i++) {
71
+ float diff = fabs(x[i] - y[i]);
72
+ accu += powf(diff, metric_arg);
73
+ }
74
+ return accu;
75
+ }
76
+
77
+ template <>
78
+ inline float VectorDistance<METRIC_Canberra>::operator()(
79
+ const float* x,
80
+ const float* y) const {
81
+ float accu = 0;
82
+ for (size_t i = 0; i < d; i++) {
83
+ float xi = x[i], yi = y[i];
84
+ accu += fabs(xi - yi) / (fabs(xi) + fabs(yi));
85
+ }
86
+ return accu;
87
+ }
88
+
89
+ template <>
90
+ inline float VectorDistance<METRIC_BrayCurtis>::operator()(
91
+ const float* x,
92
+ const float* y) const {
93
+ float accu_num = 0, accu_den = 0;
94
+ for (size_t i = 0; i < d; i++) {
95
+ float xi = x[i], yi = y[i];
96
+ accu_num += fabs(xi - yi);
97
+ accu_den += fabs(xi + yi);
98
+ }
99
+ return accu_num / accu_den;
100
+ }
101
+
102
+ template <>
103
+ inline float VectorDistance<METRIC_JensenShannon>::operator()(
104
+ const float* x,
105
+ const float* y) const {
106
+ float accu = 0;
107
+ for (size_t i = 0; i < d; i++) {
108
+ float xi = x[i], yi = y[i];
109
+ float mi = 0.5 * (xi + yi);
110
+ float kl1 = -xi * log(mi / xi);
111
+ float kl2 = -yi * log(mi / yi);
112
+ accu += kl1 + kl2;
113
+ }
114
+ return 0.5 * accu;
115
+ }
116
+
117
+ } // namespace faiss
@@ -7,16 +7,15 @@
7
7
 
8
8
  // -*- c++ -*-
9
9
 
10
- #include <faiss/utils/distances.h>
10
+ #include <faiss/utils/extra_distances.h>
11
11
 
12
+ #include <omp.h>
12
13
  #include <algorithm>
13
14
  #include <cmath>
14
- #include <omp.h>
15
-
16
15
 
17
- #include <faiss/utils/utils.h>
18
- #include <faiss/impl/FaissAssert.h>
19
16
  #include <faiss/impl/AuxIndexStructures.h>
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/utils/utils.h>
20
19
 
21
20
  namespace faiss {
22
21
 
@@ -24,140 +23,43 @@ namespace faiss {
24
23
  * Distance functions (other than L2 and IP)
25
24
  ***************************************************************************/
26
25
 
27
- struct VectorDistanceL2 {
28
- size_t d;
29
-
30
- float operator () (const float *x, const float *y) const {
31
- return fvec_L2sqr (x, y, d);
32
- }
33
- };
34
-
35
- struct VectorDistanceL1 {
36
- size_t d;
37
-
38
- float operator () (const float *x, const float *y) const {
39
- return fvec_L1 (x, y, d);
40
- }
41
- };
42
-
43
- struct VectorDistanceLinf {
44
- size_t d;
45
-
46
- float operator () (const float *x, const float *y) const {
47
- return fvec_Linf (x, y, d);
48
- /*
49
- float vmax = 0;
50
- for (size_t i = 0; i < d; i++) {
51
- float diff = fabs (x[i] - y[i]);
52
- if (diff > vmax) vmax = diff;
53
- }
54
- return vmax;*/
55
- }
56
- };
57
-
58
- struct VectorDistanceLp {
59
- size_t d;
60
- const float p;
61
-
62
- float operator () (const float *x, const float *y) const {
63
- float accu = 0;
64
- for (size_t i = 0; i < d; i++) {
65
- float diff = fabs (x[i] - y[i]);
66
- accu += powf (diff, p);
67
- }
68
- return accu;
69
- }
70
- };
71
-
72
- struct VectorDistanceCanberra {
73
- size_t d;
74
-
75
- float operator () (const float *x, const float *y) const {
76
- float accu = 0;
77
- for (size_t i = 0; i < d; i++) {
78
- float xi = x[i], yi = y[i];
79
- accu += fabs (xi - yi) / (fabs(xi) + fabs(yi));
80
- }
81
- return accu;
82
- }
83
- };
84
-
85
- struct VectorDistanceBrayCurtis {
86
- size_t d;
87
-
88
- float operator () (const float *x, const float *y) const {
89
- float accu_num = 0, accu_den = 0;
90
- for (size_t i = 0; i < d; i++) {
91
- float xi = x[i], yi = y[i];
92
- accu_num += fabs (xi - yi);
93
- accu_den += fabs (xi + yi);
94
- }
95
- return accu_num / accu_den;
96
- }
97
- };
98
-
99
- struct VectorDistanceJensenShannon {
100
- size_t d;
101
-
102
- float operator () (const float *x, const float *y) const {
103
- float accu = 0;
104
-
105
- for (size_t i = 0; i < d; i++) {
106
- float xi = x[i], yi = y[i];
107
- float mi = 0.5 * (xi + yi);
108
- float kl1 = - xi * log(mi / xi);
109
- float kl2 = - yi * log(mi / yi);
110
- accu += kl1 + kl2;
111
- }
112
- return 0.5 * accu;
113
- }
114
- };
115
-
116
-
117
-
118
-
119
-
120
-
121
-
122
-
123
-
124
-
125
26
  namespace {
126
27
 
127
- template<class VD>
128
- void pairwise_extra_distances_template (
129
- VD vd,
130
- int64_t nq, const float *xq,
131
- int64_t nb, const float *xb,
132
- float *dis,
133
- int64_t ldq, int64_t ldb, int64_t ldd)
134
- {
135
-
136
- #pragma omp parallel for if(nq > 10)
28
+ template <class VD>
29
+ void pairwise_extra_distances_template(
30
+ VD vd,
31
+ int64_t nq,
32
+ const float* xq,
33
+ int64_t nb,
34
+ const float* xb,
35
+ float* dis,
36
+ int64_t ldq,
37
+ int64_t ldb,
38
+ int64_t ldd) {
39
+ #pragma omp parallel for if (nq > 10)
137
40
  for (int64_t i = 0; i < nq; i++) {
138
- const float *xqi = xq + i * ldq;
139
- const float *xbj = xb;
140
- float *disi = dis + ldd * i;
41
+ const float* xqi = xq + i * ldq;
42
+ const float* xbj = xb;
43
+ float* disi = dis + ldd * i;
141
44
 
142
45
  for (int64_t j = 0; j < nb; j++) {
143
- disi[j] = vd (xqi, xbj);
46
+ disi[j] = vd(xqi, xbj);
144
47
  xbj += ldb;
145
48
  }
146
49
  }
147
50
  }
148
51
 
149
-
150
- template<class VD>
151
- void knn_extra_metrics_template (
52
+ template <class VD>
53
+ void knn_extra_metrics_template(
152
54
  VD vd,
153
- const float * x,
154
- const float * y,
155
- size_t nx, size_t ny,
156
- float_maxheap_array_t * res)
157
- {
55
+ const float* x,
56
+ const float* y,
57
+ size_t nx,
58
+ size_t ny,
59
+ float_maxheap_array_t* res) {
158
60
  size_t k = res->k;
159
61
  size_t d = vd.d;
160
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
62
+ size_t check_period = InterruptCallback::get_period_hint(ny * d);
161
63
  check_period *= omp_get_max_threads();
162
64
 
163
65
  for (size_t i0 = 0; i0 < nx; i0 += check_period) {
@@ -165,90 +67,84 @@ void knn_extra_metrics_template (
165
67
 
166
68
  #pragma omp parallel for
167
69
  for (int64_t i = i0; i < i1; i++) {
168
- const float * x_i = x + i * d;
169
- const float * y_j = y;
70
+ const float* x_i = x + i * d;
71
+ const float* y_j = y;
170
72
  size_t j;
171
- float * simi = res->get_val(i);
172
- int64_t * idxi = res->get_ids (i);
73
+ float* simi = res->get_val(i);
74
+ int64_t* idxi = res->get_ids(i);
173
75
 
174
- maxheap_heapify (k, simi, idxi);
76
+ maxheap_heapify(k, simi, idxi);
175
77
  for (j = 0; j < ny; j++) {
176
- float disij = vd (x_i, y_j);
78
+ float disij = vd(x_i, y_j);
177
79
 
178
80
  if (disij < simi[0]) {
179
- maxheap_replace_top (k, simi, idxi, disij, j);
81
+ maxheap_replace_top(k, simi, idxi, disij, j);
180
82
  }
181
83
  y_j += d;
182
84
  }
183
- maxheap_reorder (k, simi, idxi);
85
+ maxheap_reorder(k, simi, idxi);
184
86
  }
185
- InterruptCallback::check ();
87
+ InterruptCallback::check();
186
88
  }
187
-
188
89
  }
189
90
 
190
-
191
- template<class VD>
91
+ template <class VD>
192
92
  struct ExtraDistanceComputer : DistanceComputer {
193
93
  VD vd;
194
94
  Index::idx_t nb;
195
- const float *q;
196
- const float *b;
95
+ const float* q;
96
+ const float* b;
197
97
 
198
- float operator () (idx_t i) override {
199
- return vd (q, b + i * vd.d);
98
+ float operator()(idx_t i) override {
99
+ return vd(q, b + i * vd.d);
200
100
  }
201
101
 
202
102
  float symmetric_dis(idx_t i, idx_t j) override {
203
- return vd (b + j * vd.d, b + i * vd.d);
103
+ return vd(b + j * vd.d, b + i * vd.d);
204
104
  }
205
105
 
206
- ExtraDistanceComputer(const VD & vd, const float *xb,
207
- size_t nb, const float *q = nullptr)
208
- : vd(vd), nb(nb), q(q), b(xb) {}
106
+ ExtraDistanceComputer(
107
+ const VD& vd,
108
+ const float* xb,
109
+ size_t nb,
110
+ const float* q = nullptr)
111
+ : vd(vd), nb(nb), q(q), b(xb) {}
209
112
 
210
- void set_query(const float *x) override {
113
+ void set_query(const float* x) override {
211
114
  q = x;
212
115
  }
213
116
  };
214
117
 
215
-
216
-
217
-
218
-
219
-
220
-
221
-
222
-
223
-
224
-
225
-
226
-
227
-
228
-
229
-
230
118
  } // anonymous namespace
231
119
 
232
- void pairwise_extra_distances (
233
- int64_t d,
234
- int64_t nq, const float *xq,
235
- int64_t nb, const float *xb,
236
- MetricType mt, float metric_arg,
237
- float *dis,
238
- int64_t ldq, int64_t ldb, int64_t ldd)
239
- {
240
- if (nq == 0 || nb == 0) return;
241
- if (ldq == -1) ldq = d;
242
- if (ldb == -1) ldb = d;
243
- if (ldd == -1) ldd = nb;
244
-
245
- switch(mt) {
246
- #define HANDLE_VAR(kw) \
247
- case METRIC_ ## kw: { \
248
- VectorDistance ## kw vd = {(size_t)d}; \
249
- pairwise_extra_distances_template (vd, nq, xq, nb, xb, \
250
- dis, ldq, ldb, ldd); \
251
- break; \
120
+ void pairwise_extra_distances(
121
+ int64_t d,
122
+ int64_t nq,
123
+ const float* xq,
124
+ int64_t nb,
125
+ const float* xb,
126
+ MetricType mt,
127
+ float metric_arg,
128
+ float* dis,
129
+ int64_t ldq,
130
+ int64_t ldb,
131
+ int64_t ldd) {
132
+ if (nq == 0 || nb == 0)
133
+ return;
134
+ if (ldq == -1)
135
+ ldq = d;
136
+ if (ldb == -1)
137
+ ldb = d;
138
+ if (ldd == -1)
139
+ ldd = nb;
140
+
141
+ switch (mt) {
142
+ #define HANDLE_VAR(kw) \
143
+ case METRIC_##kw: { \
144
+ VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
145
+ pairwise_extra_distances_template( \
146
+ vd, nq, xq, nb, xb, dis, ldq, ldb, ldd); \
147
+ break; \
252
148
  }
253
149
  HANDLE_VAR(L2);
254
150
  HANDLE_VAR(L1);
@@ -256,33 +152,28 @@ void pairwise_extra_distances (
256
152
  HANDLE_VAR(Canberra);
257
153
  HANDLE_VAR(BrayCurtis);
258
154
  HANDLE_VAR(JensenShannon);
155
+ HANDLE_VAR(Lp);
259
156
  #undef HANDLE_VAR
260
- case METRIC_Lp: {
261
- VectorDistanceLp vd = {(size_t)d, metric_arg};
262
- pairwise_extra_distances_template (vd, nq, xq, nb, xb,
263
- dis, ldq, ldb, ldd);
264
- break;
157
+ default:
158
+ FAISS_THROW_MSG("metric type not implemented");
265
159
  }
266
- default:
267
- FAISS_THROW_MSG ("metric type not implemented");
268
- }
269
-
270
160
  }
271
161
 
272
- void knn_extra_metrics (
273
- const float * x,
274
- const float * y,
275
- size_t d, size_t nx, size_t ny,
276
- MetricType mt, float metric_arg,
277
- float_maxheap_array_t * res)
278
- {
279
-
280
- switch(mt) {
281
- #define HANDLE_VAR(kw) \
282
- case METRIC_ ## kw: { \
283
- VectorDistance ## kw vd = {(size_t)d}; \
284
- knn_extra_metrics_template (vd, x, y, nx, ny, res); \
285
- break; \
162
+ void knn_extra_metrics(
163
+ const float* x,
164
+ const float* y,
165
+ size_t d,
166
+ size_t nx,
167
+ size_t ny,
168
+ MetricType mt,
169
+ float metric_arg,
170
+ float_maxheap_array_t* res) {
171
+ switch (mt) {
172
+ #define HANDLE_VAR(kw) \
173
+ case METRIC_##kw: { \
174
+ VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
175
+ knn_extra_metrics_template(vd, x, y, nx, ny, res); \
176
+ break; \
286
177
  }
287
178
  HANDLE_VAR(L2);
288
179
  HANDLE_VAR(L1);
@@ -290,29 +181,25 @@ void knn_extra_metrics (
290
181
  HANDLE_VAR(Canberra);
291
182
  HANDLE_VAR(BrayCurtis);
292
183
  HANDLE_VAR(JensenShannon);
184
+ HANDLE_VAR(Lp);
293
185
  #undef HANDLE_VAR
294
- case METRIC_Lp: {
295
- VectorDistanceLp vd = {(size_t)d, metric_arg};
296
- knn_extra_metrics_template (vd, x, y, nx, ny, res);
297
- break;
298
- }
299
- default:
300
- FAISS_THROW_MSG ("metric type not implemented");
186
+ default:
187
+ FAISS_THROW_MSG("metric type not implemented");
301
188
  }
302
-
303
189
  }
304
190
 
305
- DistanceComputer *get_extra_distance_computer (
191
+ DistanceComputer* get_extra_distance_computer(
306
192
  size_t d,
307
- MetricType mt, float metric_arg,
308
- size_t nb, const float *xb)
309
- {
310
-
311
- switch(mt) {
312
- #define HANDLE_VAR(kw) \
313
- case METRIC_ ## kw: { \
314
- VectorDistance ## kw vd = {(size_t)d}; \
315
- return new ExtraDistanceComputer<VectorDistance ## kw>(vd, xb, nb); \
193
+ MetricType mt,
194
+ float metric_arg,
195
+ size_t nb,
196
+ const float* xb) {
197
+ switch (mt) {
198
+ #define HANDLE_VAR(kw) \
199
+ case METRIC_##kw: { \
200
+ VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
201
+ return new ExtraDistanceComputer<VectorDistance<METRIC_##kw>>( \
202
+ vd, xb, nb); \
316
203
  }
317
204
  HANDLE_VAR(L2);
318
205
  HANDLE_VAR(L1);
@@ -320,17 +207,11 @@ DistanceComputer *get_extra_distance_computer (
320
207
  HANDLE_VAR(Canberra);
321
208
  HANDLE_VAR(BrayCurtis);
322
209
  HANDLE_VAR(JensenShannon);
210
+ HANDLE_VAR(Lp);
323
211
  #undef HANDLE_VAR
324
- case METRIC_Lp: {
325
- VectorDistanceLp vd = {(size_t)d, metric_arg};
326
- return new ExtraDistanceComputer<VectorDistanceLp> (vd, xb, nb);
327
- break;
328
- }
329
- default:
330
- FAISS_THROW_MSG ("metric type not implemented");
212
+ default:
213
+ FAISS_THROW_MSG("metric type not implemented");
331
214
  }
332
-
333
215
  }
334
216
 
335
-
336
217
  } // namespace faiss
@@ -5,10 +5,7 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
- #ifndef FAISS_distances_h
11
- #define FAISS_distances_h
8
+ #pragma once
12
9
 
13
10
  /** In this file are the implementations of extra metrics beyond L2
14
11
  * and inner product */
@@ -19,36 +16,40 @@
19
16
 
20
17
  #include <faiss/utils/Heap.h>
21
18
 
22
-
23
-
24
19
  namespace faiss {
25
20
 
26
-
27
- void pairwise_extra_distances (
28
- int64_t d,
29
- int64_t nq, const float *xq,
30
- int64_t nb, const float *xb,
31
- MetricType mt, float metric_arg,
32
- float *dis,
33
- int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1);
34
-
35
-
36
- void knn_extra_metrics (
37
- const float * x,
38
- const float * y,
39
- size_t d, size_t nx, size_t ny,
40
- MetricType mt, float metric_arg,
41
- float_maxheap_array_t * res);
42
-
21
+ void pairwise_extra_distances(
22
+ int64_t d,
23
+ int64_t nq,
24
+ const float* xq,
25
+ int64_t nb,
26
+ const float* xb,
27
+ MetricType mt,
28
+ float metric_arg,
29
+ float* dis,
30
+ int64_t ldq = -1,
31
+ int64_t ldb = -1,
32
+ int64_t ldd = -1);
33
+
34
+ void knn_extra_metrics(
35
+ const float* x,
36
+ const float* y,
37
+ size_t d,
38
+ size_t nx,
39
+ size_t ny,
40
+ MetricType mt,
41
+ float metric_arg,
42
+ float_maxheap_array_t* res);
43
43
 
44
44
  /** get a DistanceComputer that refers to this type of distance and
45
45
  * indexes a flat array of size nb */
46
- DistanceComputer *get_extra_distance_computer (
46
+ DistanceComputer* get_extra_distance_computer(
47
47
  size_t d,
48
- MetricType mt, float metric_arg,
49
- size_t nb, const float *xb);
50
-
51
- }
48
+ MetricType mt,
49
+ float metric_arg,
50
+ size_t nb,
51
+ const float* xb);
52
52
 
53
+ } // namespace faiss
53
54
 
54
- #endif
55
+ #include <faiss/utils/extra_distances-inl.h>