faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,117 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ /** In this file are the implementations of extra metrics beyond L2
9
+ * and inner product */
10
+
11
+ #include <faiss/utils/distances.h>
12
+ #include <type_traits>
13
+
14
+ namespace faiss {
15
+
16
+ template <MetricType mt>
17
+ struct VectorDistance {
18
+ size_t d;
19
+ float metric_arg;
20
+
21
+ inline float operator()(const float* x, const float* y) const;
22
+
23
+ // heap template to use for this type of metric
24
+ using C = typename std::conditional<
25
+ mt == METRIC_INNER_PRODUCT,
26
+ CMin<float, int64_t>,
27
+ CMax<float, int64_t>>::type;
28
+ };
29
+
30
+ template <>
31
+ inline float VectorDistance<METRIC_L2>::operator()(
32
+ const float* x,
33
+ const float* y) const {
34
+ return fvec_L2sqr(x, y, d);
35
+ }
36
+
37
+ template <>
38
+ inline float VectorDistance<METRIC_INNER_PRODUCT>::operator()(
39
+ const float* x,
40
+ const float* y) const {
41
+ return fvec_inner_product(x, y, d);
42
+ }
43
+
44
+ template <>
45
+ inline float VectorDistance<METRIC_L1>::operator()(
46
+ const float* x,
47
+ const float* y) const {
48
+ return fvec_L1(x, y, d);
49
+ }
50
+
51
+ template <>
52
+ inline float VectorDistance<METRIC_Linf>::operator()(
53
+ const float* x,
54
+ const float* y) const {
55
+ return fvec_Linf(x, y, d);
56
+ /*
57
+ float vmax = 0;
58
+ for (size_t i = 0; i < d; i++) {
59
+ float diff = fabs (x[i] - y[i]);
60
+ if (diff > vmax) vmax = diff;
61
+ }
62
+ return vmax;*/
63
+ }
64
+
65
+ template <>
66
+ inline float VectorDistance<METRIC_Lp>::operator()(
67
+ const float* x,
68
+ const float* y) const {
69
+ float accu = 0;
70
+ for (size_t i = 0; i < d; i++) {
71
+ float diff = fabs(x[i] - y[i]);
72
+ accu += powf(diff, metric_arg);
73
+ }
74
+ return accu;
75
+ }
76
+
77
+ template <>
78
+ inline float VectorDistance<METRIC_Canberra>::operator()(
79
+ const float* x,
80
+ const float* y) const {
81
+ float accu = 0;
82
+ for (size_t i = 0; i < d; i++) {
83
+ float xi = x[i], yi = y[i];
84
+ accu += fabs(xi - yi) / (fabs(xi) + fabs(yi));
85
+ }
86
+ return accu;
87
+ }
88
+
89
+ template <>
90
+ inline float VectorDistance<METRIC_BrayCurtis>::operator()(
91
+ const float* x,
92
+ const float* y) const {
93
+ float accu_num = 0, accu_den = 0;
94
+ for (size_t i = 0; i < d; i++) {
95
+ float xi = x[i], yi = y[i];
96
+ accu_num += fabs(xi - yi);
97
+ accu_den += fabs(xi + yi);
98
+ }
99
+ return accu_num / accu_den;
100
+ }
101
+
102
+ template <>
103
+ inline float VectorDistance<METRIC_JensenShannon>::operator()(
104
+ const float* x,
105
+ const float* y) const {
106
+ float accu = 0;
107
+ for (size_t i = 0; i < d; i++) {
108
+ float xi = x[i], yi = y[i];
109
+ float mi = 0.5 * (xi + yi);
110
+ float kl1 = -xi * log(mi / xi);
111
+ float kl2 = -yi * log(mi / yi);
112
+ accu += kl1 + kl2;
113
+ }
114
+ return 0.5 * accu;
115
+ }
116
+
117
+ } // namespace faiss
@@ -7,16 +7,15 @@
7
7
 
8
8
  // -*- c++ -*-
9
9
 
10
- #include <faiss/utils/distances.h>
10
+ #include <faiss/utils/extra_distances.h>
11
11
 
12
+ #include <omp.h>
12
13
  #include <algorithm>
13
14
  #include <cmath>
14
- #include <omp.h>
15
-
16
15
 
17
- #include <faiss/utils/utils.h>
18
- #include <faiss/impl/FaissAssert.h>
19
16
  #include <faiss/impl/AuxIndexStructures.h>
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/utils/utils.h>
20
19
 
21
20
  namespace faiss {
22
21
 
@@ -24,140 +23,43 @@ namespace faiss {
24
23
  * Distance functions (other than L2 and IP)
25
24
  ***************************************************************************/
26
25
 
27
- struct VectorDistanceL2 {
28
- size_t d;
29
-
30
- float operator () (const float *x, const float *y) const {
31
- return fvec_L2sqr (x, y, d);
32
- }
33
- };
34
-
35
- struct VectorDistanceL1 {
36
- size_t d;
37
-
38
- float operator () (const float *x, const float *y) const {
39
- return fvec_L1 (x, y, d);
40
- }
41
- };
42
-
43
- struct VectorDistanceLinf {
44
- size_t d;
45
-
46
- float operator () (const float *x, const float *y) const {
47
- return fvec_Linf (x, y, d);
48
- /*
49
- float vmax = 0;
50
- for (size_t i = 0; i < d; i++) {
51
- float diff = fabs (x[i] - y[i]);
52
- if (diff > vmax) vmax = diff;
53
- }
54
- return vmax;*/
55
- }
56
- };
57
-
58
- struct VectorDistanceLp {
59
- size_t d;
60
- const float p;
61
-
62
- float operator () (const float *x, const float *y) const {
63
- float accu = 0;
64
- for (size_t i = 0; i < d; i++) {
65
- float diff = fabs (x[i] - y[i]);
66
- accu += powf (diff, p);
67
- }
68
- return accu;
69
- }
70
- };
71
-
72
- struct VectorDistanceCanberra {
73
- size_t d;
74
-
75
- float operator () (const float *x, const float *y) const {
76
- float accu = 0;
77
- for (size_t i = 0; i < d; i++) {
78
- float xi = x[i], yi = y[i];
79
- accu += fabs (xi - yi) / (fabs(xi) + fabs(yi));
80
- }
81
- return accu;
82
- }
83
- };
84
-
85
- struct VectorDistanceBrayCurtis {
86
- size_t d;
87
-
88
- float operator () (const float *x, const float *y) const {
89
- float accu_num = 0, accu_den = 0;
90
- for (size_t i = 0; i < d; i++) {
91
- float xi = x[i], yi = y[i];
92
- accu_num += fabs (xi - yi);
93
- accu_den += fabs (xi + yi);
94
- }
95
- return accu_num / accu_den;
96
- }
97
- };
98
-
99
- struct VectorDistanceJensenShannon {
100
- size_t d;
101
-
102
- float operator () (const float *x, const float *y) const {
103
- float accu = 0;
104
-
105
- for (size_t i = 0; i < d; i++) {
106
- float xi = x[i], yi = y[i];
107
- float mi = 0.5 * (xi + yi);
108
- float kl1 = - xi * log(mi / xi);
109
- float kl2 = - yi * log(mi / yi);
110
- accu += kl1 + kl2;
111
- }
112
- return 0.5 * accu;
113
- }
114
- };
115
-
116
-
117
-
118
-
119
-
120
-
121
-
122
-
123
-
124
-
125
26
  namespace {
126
27
 
127
- template<class VD>
128
- void pairwise_extra_distances_template (
129
- VD vd,
130
- int64_t nq, const float *xq,
131
- int64_t nb, const float *xb,
132
- float *dis,
133
- int64_t ldq, int64_t ldb, int64_t ldd)
134
- {
135
-
136
- #pragma omp parallel for if(nq > 10)
28
+ template <class VD>
29
+ void pairwise_extra_distances_template(
30
+ VD vd,
31
+ int64_t nq,
32
+ const float* xq,
33
+ int64_t nb,
34
+ const float* xb,
35
+ float* dis,
36
+ int64_t ldq,
37
+ int64_t ldb,
38
+ int64_t ldd) {
39
+ #pragma omp parallel for if (nq > 10)
137
40
  for (int64_t i = 0; i < nq; i++) {
138
- const float *xqi = xq + i * ldq;
139
- const float *xbj = xb;
140
- float *disi = dis + ldd * i;
41
+ const float* xqi = xq + i * ldq;
42
+ const float* xbj = xb;
43
+ float* disi = dis + ldd * i;
141
44
 
142
45
  for (int64_t j = 0; j < nb; j++) {
143
- disi[j] = vd (xqi, xbj);
46
+ disi[j] = vd(xqi, xbj);
144
47
  xbj += ldb;
145
48
  }
146
49
  }
147
50
  }
148
51
 
149
-
150
- template<class VD>
151
- void knn_extra_metrics_template (
52
+ template <class VD>
53
+ void knn_extra_metrics_template(
152
54
  VD vd,
153
- const float * x,
154
- const float * y,
155
- size_t nx, size_t ny,
156
- float_maxheap_array_t * res)
157
- {
55
+ const float* x,
56
+ const float* y,
57
+ size_t nx,
58
+ size_t ny,
59
+ float_maxheap_array_t* res) {
158
60
  size_t k = res->k;
159
61
  size_t d = vd.d;
160
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
62
+ size_t check_period = InterruptCallback::get_period_hint(ny * d);
161
63
  check_period *= omp_get_max_threads();
162
64
 
163
65
  for (size_t i0 = 0; i0 < nx; i0 += check_period) {
@@ -165,90 +67,84 @@ void knn_extra_metrics_template (
165
67
 
166
68
  #pragma omp parallel for
167
69
  for (int64_t i = i0; i < i1; i++) {
168
- const float * x_i = x + i * d;
169
- const float * y_j = y;
70
+ const float* x_i = x + i * d;
71
+ const float* y_j = y;
170
72
  size_t j;
171
- float * simi = res->get_val(i);
172
- int64_t * idxi = res->get_ids (i);
73
+ float* simi = res->get_val(i);
74
+ int64_t* idxi = res->get_ids(i);
173
75
 
174
- maxheap_heapify (k, simi, idxi);
76
+ maxheap_heapify(k, simi, idxi);
175
77
  for (j = 0; j < ny; j++) {
176
- float disij = vd (x_i, y_j);
78
+ float disij = vd(x_i, y_j);
177
79
 
178
80
  if (disij < simi[0]) {
179
- maxheap_replace_top (k, simi, idxi, disij, j);
81
+ maxheap_replace_top(k, simi, idxi, disij, j);
180
82
  }
181
83
  y_j += d;
182
84
  }
183
- maxheap_reorder (k, simi, idxi);
85
+ maxheap_reorder(k, simi, idxi);
184
86
  }
185
- InterruptCallback::check ();
87
+ InterruptCallback::check();
186
88
  }
187
-
188
89
  }
189
90
 
190
-
191
- template<class VD>
91
+ template <class VD>
192
92
  struct ExtraDistanceComputer : DistanceComputer {
193
93
  VD vd;
194
94
  Index::idx_t nb;
195
- const float *q;
196
- const float *b;
95
+ const float* q;
96
+ const float* b;
197
97
 
198
- float operator () (idx_t i) override {
199
- return vd (q, b + i * vd.d);
98
+ float operator()(idx_t i) override {
99
+ return vd(q, b + i * vd.d);
200
100
  }
201
101
 
202
102
  float symmetric_dis(idx_t i, idx_t j) override {
203
- return vd (b + j * vd.d, b + i * vd.d);
103
+ return vd(b + j * vd.d, b + i * vd.d);
204
104
  }
205
105
 
206
- ExtraDistanceComputer(const VD & vd, const float *xb,
207
- size_t nb, const float *q = nullptr)
208
- : vd(vd), nb(nb), q(q), b(xb) {}
106
+ ExtraDistanceComputer(
107
+ const VD& vd,
108
+ const float* xb,
109
+ size_t nb,
110
+ const float* q = nullptr)
111
+ : vd(vd), nb(nb), q(q), b(xb) {}
209
112
 
210
- void set_query(const float *x) override {
113
+ void set_query(const float* x) override {
211
114
  q = x;
212
115
  }
213
116
  };
214
117
 
215
-
216
-
217
-
218
-
219
-
220
-
221
-
222
-
223
-
224
-
225
-
226
-
227
-
228
-
229
-
230
118
  } // anonymous namespace
231
119
 
232
- void pairwise_extra_distances (
233
- int64_t d,
234
- int64_t nq, const float *xq,
235
- int64_t nb, const float *xb,
236
- MetricType mt, float metric_arg,
237
- float *dis,
238
- int64_t ldq, int64_t ldb, int64_t ldd)
239
- {
240
- if (nq == 0 || nb == 0) return;
241
- if (ldq == -1) ldq = d;
242
- if (ldb == -1) ldb = d;
243
- if (ldd == -1) ldd = nb;
244
-
245
- switch(mt) {
246
- #define HANDLE_VAR(kw) \
247
- case METRIC_ ## kw: { \
248
- VectorDistance ## kw vd = {(size_t)d}; \
249
- pairwise_extra_distances_template (vd, nq, xq, nb, xb, \
250
- dis, ldq, ldb, ldd); \
251
- break; \
120
+ void pairwise_extra_distances(
121
+ int64_t d,
122
+ int64_t nq,
123
+ const float* xq,
124
+ int64_t nb,
125
+ const float* xb,
126
+ MetricType mt,
127
+ float metric_arg,
128
+ float* dis,
129
+ int64_t ldq,
130
+ int64_t ldb,
131
+ int64_t ldd) {
132
+ if (nq == 0 || nb == 0)
133
+ return;
134
+ if (ldq == -1)
135
+ ldq = d;
136
+ if (ldb == -1)
137
+ ldb = d;
138
+ if (ldd == -1)
139
+ ldd = nb;
140
+
141
+ switch (mt) {
142
+ #define HANDLE_VAR(kw) \
143
+ case METRIC_##kw: { \
144
+ VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
145
+ pairwise_extra_distances_template( \
146
+ vd, nq, xq, nb, xb, dis, ldq, ldb, ldd); \
147
+ break; \
252
148
  }
253
149
  HANDLE_VAR(L2);
254
150
  HANDLE_VAR(L1);
@@ -256,33 +152,28 @@ void pairwise_extra_distances (
256
152
  HANDLE_VAR(Canberra);
257
153
  HANDLE_VAR(BrayCurtis);
258
154
  HANDLE_VAR(JensenShannon);
155
+ HANDLE_VAR(Lp);
259
156
  #undef HANDLE_VAR
260
- case METRIC_Lp: {
261
- VectorDistanceLp vd = {(size_t)d, metric_arg};
262
- pairwise_extra_distances_template (vd, nq, xq, nb, xb,
263
- dis, ldq, ldb, ldd);
264
- break;
157
+ default:
158
+ FAISS_THROW_MSG("metric type not implemented");
265
159
  }
266
- default:
267
- FAISS_THROW_MSG ("metric type not implemented");
268
- }
269
-
270
160
  }
271
161
 
272
- void knn_extra_metrics (
273
- const float * x,
274
- const float * y,
275
- size_t d, size_t nx, size_t ny,
276
- MetricType mt, float metric_arg,
277
- float_maxheap_array_t * res)
278
- {
279
-
280
- switch(mt) {
281
- #define HANDLE_VAR(kw) \
282
- case METRIC_ ## kw: { \
283
- VectorDistance ## kw vd = {(size_t)d}; \
284
- knn_extra_metrics_template (vd, x, y, nx, ny, res); \
285
- break; \
162
+ void knn_extra_metrics(
163
+ const float* x,
164
+ const float* y,
165
+ size_t d,
166
+ size_t nx,
167
+ size_t ny,
168
+ MetricType mt,
169
+ float metric_arg,
170
+ float_maxheap_array_t* res) {
171
+ switch (mt) {
172
+ #define HANDLE_VAR(kw) \
173
+ case METRIC_##kw: { \
174
+ VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
175
+ knn_extra_metrics_template(vd, x, y, nx, ny, res); \
176
+ break; \
286
177
  }
287
178
  HANDLE_VAR(L2);
288
179
  HANDLE_VAR(L1);
@@ -290,29 +181,25 @@ void knn_extra_metrics (
290
181
  HANDLE_VAR(Canberra);
291
182
  HANDLE_VAR(BrayCurtis);
292
183
  HANDLE_VAR(JensenShannon);
184
+ HANDLE_VAR(Lp);
293
185
  #undef HANDLE_VAR
294
- case METRIC_Lp: {
295
- VectorDistanceLp vd = {(size_t)d, metric_arg};
296
- knn_extra_metrics_template (vd, x, y, nx, ny, res);
297
- break;
298
- }
299
- default:
300
- FAISS_THROW_MSG ("metric type not implemented");
186
+ default:
187
+ FAISS_THROW_MSG("metric type not implemented");
301
188
  }
302
-
303
189
  }
304
190
 
305
- DistanceComputer *get_extra_distance_computer (
191
+ DistanceComputer* get_extra_distance_computer(
306
192
  size_t d,
307
- MetricType mt, float metric_arg,
308
- size_t nb, const float *xb)
309
- {
310
-
311
- switch(mt) {
312
- #define HANDLE_VAR(kw) \
313
- case METRIC_ ## kw: { \
314
- VectorDistance ## kw vd = {(size_t)d}; \
315
- return new ExtraDistanceComputer<VectorDistance ## kw>(vd, xb, nb); \
193
+ MetricType mt,
194
+ float metric_arg,
195
+ size_t nb,
196
+ const float* xb) {
197
+ switch (mt) {
198
+ #define HANDLE_VAR(kw) \
199
+ case METRIC_##kw: { \
200
+ VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
201
+ return new ExtraDistanceComputer<VectorDistance<METRIC_##kw>>( \
202
+ vd, xb, nb); \
316
203
  }
317
204
  HANDLE_VAR(L2);
318
205
  HANDLE_VAR(L1);
@@ -320,17 +207,11 @@ DistanceComputer *get_extra_distance_computer (
320
207
  HANDLE_VAR(Canberra);
321
208
  HANDLE_VAR(BrayCurtis);
322
209
  HANDLE_VAR(JensenShannon);
210
+ HANDLE_VAR(Lp);
323
211
  #undef HANDLE_VAR
324
- case METRIC_Lp: {
325
- VectorDistanceLp vd = {(size_t)d, metric_arg};
326
- return new ExtraDistanceComputer<VectorDistanceLp> (vd, xb, nb);
327
- break;
328
- }
329
- default:
330
- FAISS_THROW_MSG ("metric type not implemented");
212
+ default:
213
+ FAISS_THROW_MSG("metric type not implemented");
331
214
  }
332
-
333
215
  }
334
216
 
335
-
336
217
  } // namespace faiss