faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -9,7 +9,9 @@
9
9
  * and inner product */
10
10
 
11
11
  #include <faiss/MetricType.h>
12
+ #include <faiss/impl/FaissAssert.h>
12
13
  #include <faiss/utils/distances.h>
14
+ #include <cmath>
13
15
  #include <type_traits>
14
16
 
15
17
  namespace faiss {
@@ -130,4 +132,72 @@ inline float VectorDistance<METRIC_Jaccard>::operator()(
130
132
  return accu_num / accu_den;
131
133
  }
132
134
 
135
+ template <>
136
+ inline float VectorDistance<METRIC_NaNEuclidean>::operator()(
137
+ const float* x,
138
+ const float* y) const {
139
+ // https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html
140
+ float accu = 0;
141
+ size_t present = 0;
142
+ for (size_t i = 0; i < d; i++) {
143
+ if (!std::isnan(x[i]) && !std::isnan(y[i])) {
144
+ float diff = x[i] - y[i];
145
+ accu += diff * diff;
146
+ present++;
147
+ }
148
+ }
149
+ if (present == 0) {
150
+ return NAN;
151
+ }
152
+ return float(d) / float(present) * accu;
153
+ }
154
+
155
+ template <>
156
+ inline float VectorDistance<METRIC_ABS_INNER_PRODUCT>::operator()(
157
+ const float* x,
158
+ const float* y) const {
159
+ float accu = 0;
160
+ for (size_t i = 0; i < d; i++) {
161
+ accu += fabs(x[i] * y[i]);
162
+ }
163
+ return accu;
164
+ }
165
+
166
+ /***************************************************************************
167
+ * Dispatching function that takes a metric type and a consumer object
168
+ * the consumer object should contain a retun type T and a operation template
169
+ * function f() that is called to perform the operation. The first argument
170
+ * of the function is the VectorDistance object. The rest are passed in as is.
171
+ **************************************************************************/
172
+
173
+ template <class Consumer, class... Types>
174
+ typename Consumer::T dispatch_VectorDistance(
175
+ size_t d,
176
+ MetricType metric,
177
+ float metric_arg,
178
+ Consumer& consumer,
179
+ Types... args) {
180
+ switch (metric) {
181
+ #define DISPATCH_VD(mt) \
182
+ case mt: { \
183
+ VectorDistance<mt> vd = {d, metric_arg}; \
184
+ return consumer.template f<VectorDistance<mt>>(vd, args...); \
185
+ }
186
+ DISPATCH_VD(METRIC_INNER_PRODUCT);
187
+ DISPATCH_VD(METRIC_L2);
188
+ DISPATCH_VD(METRIC_L1);
189
+ DISPATCH_VD(METRIC_Linf);
190
+ DISPATCH_VD(METRIC_Lp);
191
+ DISPATCH_VD(METRIC_Canberra);
192
+ DISPATCH_VD(METRIC_BrayCurtis);
193
+ DISPATCH_VD(METRIC_JensenShannon);
194
+ DISPATCH_VD(METRIC_Jaccard);
195
+ DISPATCH_VD(METRIC_NaNEuclidean);
196
+ DISPATCH_VD(METRIC_ABS_INNER_PRODUCT);
197
+ default:
198
+ FAISS_THROW_FMT("Invalid metric %d", metric);
199
+ }
200
+ #undef DISPATCH_VD
201
+ }
202
+
133
203
  } // namespace faiss
@@ -26,73 +26,77 @@ namespace faiss {
26
26
 
27
27
  namespace {
28
28
 
29
- template <class VD>
30
- void pairwise_extra_distances_template(
31
- VD vd,
32
- int64_t nq,
33
- const float* xq,
34
- int64_t nb,
35
- const float* xb,
36
- float* dis,
37
- int64_t ldq,
38
- int64_t ldb,
39
- int64_t ldd) {
29
+ struct Run_pairwise_extra_distances {
30
+ using T = void;
31
+
32
+ template <class VD>
33
+ void f(VD vd,
34
+ int64_t nq,
35
+ const float* xq,
36
+ int64_t nb,
37
+ const float* xb,
38
+ float* dis,
39
+ int64_t ldq,
40
+ int64_t ldb,
41
+ int64_t ldd) {
40
42
  #pragma omp parallel for if (nq > 10)
41
- for (int64_t i = 0; i < nq; i++) {
42
- const float* xqi = xq + i * ldq;
43
- const float* xbj = xb;
44
- float* disi = dis + ldd * i;
45
-
46
- for (int64_t j = 0; j < nb; j++) {
47
- disi[j] = vd(xqi, xbj);
48
- xbj += ldb;
43
+ for (int64_t i = 0; i < nq; i++) {
44
+ const float* xqi = xq + i * ldq;
45
+ const float* xbj = xb;
46
+ float* disi = dis + ldd * i;
47
+
48
+ for (int64_t j = 0; j < nb; j++) {
49
+ disi[j] = vd(xqi, xbj);
50
+ xbj += ldb;
51
+ }
49
52
  }
50
53
  }
51
- }
52
-
53
- template <class VD, class C>
54
- void knn_extra_metrics_template(
55
- VD vd,
56
- const float* x,
57
- const float* y,
58
- size_t nx,
59
- size_t ny,
60
- HeapArray<C>* res) {
61
- size_t k = res->k;
62
- size_t d = vd.d;
63
- size_t check_period = InterruptCallback::get_period_hint(ny * d);
64
- check_period *= omp_get_max_threads();
54
+ };
65
55
 
66
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
67
- size_t i1 = std::min(i0 + check_period, nx);
56
+ struct Run_knn_extra_metrics {
57
+ using T = void;
58
+ template <class VD>
59
+ void f(VD vd,
60
+ const float* x,
61
+ const float* y,
62
+ size_t nx,
63
+ size_t ny,
64
+ size_t k,
65
+ float* distances,
66
+ int64_t* labels) {
67
+ size_t d = vd.d;
68
+ using C = typename VD::C;
69
+ size_t check_period = InterruptCallback::get_period_hint(ny * d);
70
+ check_period *= omp_get_max_threads();
71
+
72
+ for (size_t i0 = 0; i0 < nx; i0 += check_period) {
73
+ size_t i1 = std::min(i0 + check_period, nx);
68
74
 
69
75
  #pragma omp parallel for
70
- for (int64_t i = i0; i < i1; i++) {
71
- const float* x_i = x + i * d;
72
- const float* y_j = y;
73
- size_t j;
74
- float* simi = res->get_val(i);
75
- int64_t* idxi = res->get_ids(i);
76
-
77
- // maxheap_heapify(k, simi, idxi);
78
- heap_heapify<C>(k, simi, idxi);
79
- for (j = 0; j < ny; j++) {
80
- float disij = vd(x_i, y_j);
81
-
82
- // if (disij < simi[0]) {
83
- if ((!vd.is_similarity && (disij < simi[0])) ||
84
- (vd.is_similarity && (disij > simi[0]))) {
85
- // maxheap_replace_top(k, simi, idxi, disij, j);
86
- heap_replace_top<C>(k, simi, idxi, disij, j);
76
+ for (int64_t i = i0; i < i1; i++) {
77
+ const float* x_i = x + i * d;
78
+ const float* y_j = y;
79
+ size_t j;
80
+ float* simi = distances + k * i;
81
+ int64_t* idxi = labels + k * i;
82
+
83
+ // maxheap_heapify(k, simi, idxi);
84
+ heap_heapify<C>(k, simi, idxi);
85
+ for (j = 0; j < ny; j++) {
86
+ float disij = vd(x_i, y_j);
87
+
88
+ if (C::cmp(simi[0], disij)) {
89
+ heap_replace_top<C>(k, simi, idxi, disij, j);
90
+ }
91
+ y_j += d;
87
92
  }
88
- y_j += d;
93
+ // maxheap_reorder(k, simi, idxi);
94
+ heap_reorder<C>(k, simi, idxi);
89
95
  }
90
- // maxheap_reorder(k, simi, idxi);
91
- heap_reorder<C>(k, simi, idxi);
96
+ InterruptCallback::check();
92
97
  }
93
- InterruptCallback::check();
94
98
  }
95
- }
99
+ };
96
100
 
97
101
  template <class VD>
98
102
  struct ExtraDistanceComputer : FlatCodesDistanceComputer {
@@ -125,6 +129,19 @@ struct ExtraDistanceComputer : FlatCodesDistanceComputer {
125
129
  }
126
130
  };
127
131
 
132
+ struct Run_get_distance_computer {
133
+ using T = FlatCodesDistanceComputer*;
134
+
135
+ template <class VD>
136
+ FlatCodesDistanceComputer* f(
137
+ VD vd,
138
+ const float* xb,
139
+ size_t nb,
140
+ const float* q = nullptr) {
141
+ return new ExtraDistanceComputer<VD>(vd, xb, nb, q);
142
+ }
143
+ };
144
+
128
145
  } // anonymous namespace
129
146
 
130
147
  void pairwise_extra_distances(
@@ -148,29 +165,11 @@ void pairwise_extra_distances(
148
165
  if (ldd == -1)
149
166
  ldd = nb;
150
167
 
151
- switch (mt) {
152
- #define HANDLE_VAR(kw) \
153
- case METRIC_##kw: { \
154
- VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
155
- pairwise_extra_distances_template( \
156
- vd, nq, xq, nb, xb, dis, ldq, ldb, ldd); \
157
- break; \
158
- }
159
- HANDLE_VAR(L2);
160
- HANDLE_VAR(L1);
161
- HANDLE_VAR(Linf);
162
- HANDLE_VAR(Canberra);
163
- HANDLE_VAR(BrayCurtis);
164
- HANDLE_VAR(JensenShannon);
165
- HANDLE_VAR(Lp);
166
- HANDLE_VAR(Jaccard);
167
- #undef HANDLE_VAR
168
- default:
169
- FAISS_THROW_MSG("metric type not implemented");
170
- }
168
+ Run_pairwise_extra_distances run;
169
+ dispatch_VectorDistance(
170
+ d, mt, metric_arg, run, nq, xq, nb, xb, dis, ldq, ldb, ldd);
171
171
  }
172
172
 
173
- template <class C>
174
173
  void knn_extra_metrics(
175
174
  const float* x,
176
175
  const float* y,
@@ -179,73 +178,22 @@ void knn_extra_metrics(
179
178
  size_t ny,
180
179
  MetricType mt,
181
180
  float metric_arg,
182
- HeapArray<C>* res) {
183
- switch (mt) {
184
- #define HANDLE_VAR(kw) \
185
- case METRIC_##kw: { \
186
- VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
187
- knn_extra_metrics_template(vd, x, y, nx, ny, res); \
188
- break; \
189
- }
190
- HANDLE_VAR(L2);
191
- HANDLE_VAR(L1);
192
- HANDLE_VAR(Linf);
193
- HANDLE_VAR(Canberra);
194
- HANDLE_VAR(BrayCurtis);
195
- HANDLE_VAR(JensenShannon);
196
- HANDLE_VAR(Lp);
197
- HANDLE_VAR(Jaccard);
198
- #undef HANDLE_VAR
199
- default:
200
- FAISS_THROW_MSG("metric type not implemented");
201
- }
181
+ size_t k,
182
+ float* distances,
183
+ int64_t* indexes) {
184
+ Run_knn_extra_metrics run;
185
+ dispatch_VectorDistance(
186
+ d, mt, metric_arg, run, x, y, nx, ny, k, distances, indexes);
202
187
  }
203
188
 
204
- template void knn_extra_metrics<CMax<float, int64_t>>(
205
- const float* x,
206
- const float* y,
207
- size_t d,
208
- size_t nx,
209
- size_t ny,
210
- MetricType mt,
211
- float metric_arg,
212
- HeapArray<CMax<float, int64_t>>* res);
213
-
214
- template void knn_extra_metrics<CMin<float, int64_t>>(
215
- const float* x,
216
- const float* y,
217
- size_t d,
218
- size_t nx,
219
- size_t ny,
220
- MetricType mt,
221
- float metric_arg,
222
- HeapArray<CMin<float, int64_t>>* res);
223
-
224
189
  FlatCodesDistanceComputer* get_extra_distance_computer(
225
190
  size_t d,
226
191
  MetricType mt,
227
192
  float metric_arg,
228
193
  size_t nb,
229
194
  const float* xb) {
230
- switch (mt) {
231
- #define HANDLE_VAR(kw) \
232
- case METRIC_##kw: { \
233
- VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
234
- return new ExtraDistanceComputer<VectorDistance<METRIC_##kw>>( \
235
- vd, xb, nb); \
236
- }
237
- HANDLE_VAR(L2);
238
- HANDLE_VAR(L1);
239
- HANDLE_VAR(Linf);
240
- HANDLE_VAR(Canberra);
241
- HANDLE_VAR(BrayCurtis);
242
- HANDLE_VAR(JensenShannon);
243
- HANDLE_VAR(Lp);
244
- HANDLE_VAR(Jaccard);
245
- #undef HANDLE_VAR
246
- default:
247
- FAISS_THROW_MSG("metric type not implemented");
248
- }
195
+ Run_get_distance_computer run;
196
+ return dispatch_VectorDistance(d, mt, metric_arg, run, xb, nb);
249
197
  }
250
198
 
251
199
  } // namespace faiss
@@ -33,7 +33,6 @@ void pairwise_extra_distances(
33
33
  int64_t ldb = -1,
34
34
  int64_t ldd = -1);
35
35
 
36
- template <class C>
37
36
  void knn_extra_metrics(
38
37
  const float* x,
39
38
  const float* y,
@@ -42,7 +41,9 @@ void knn_extra_metrics(
42
41
  size_t ny,
43
42
  MetricType mt,
44
43
  float metric_arg,
45
- HeapArray<C>* res);
44
+ size_t k,
45
+ float* distances,
46
+ int64_t* indexes);
46
47
 
47
48
  /** get a DistanceComputer that refers to this type of distance and
48
49
  * indexes a flat array of size nb */
@@ -0,0 +1,29 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <arm_neon.h>
11
+ #include <cstdint>
12
+
13
+ namespace faiss {
14
+
15
+ inline uint16_t encode_fp16(float x) {
16
+ float32x4_t fx4 = vdupq_n_f32(x);
17
+ float16x4_t f16x4 = vcvt_f16_f32(fx4);
18
+ uint16x4_t ui16x4 = vreinterpret_u16_f16(f16x4);
19
+ return vduph_lane_u16(ui16x4, 3);
20
+ }
21
+
22
+ inline float decode_fp16(uint16_t x) {
23
+ uint16x4_t ui16x4 = vdup_n_u16(x);
24
+ float16x4_t f16x4 = vreinterpret_f16_u16(ui16x4);
25
+ float32x4_t fx4 = vcvt_f32_f16(f16x4);
26
+ return vdups_laneq_f32(fx4, 3);
27
+ }
28
+
29
+ } // namespace faiss
@@ -13,6 +13,8 @@
13
13
 
14
14
  #if defined(__F16C__)
15
15
  #include <faiss/utils/fp16-fp16c.h>
16
+ #elif defined(__aarch64__)
17
+ #include <faiss/utils/fp16-arm.h>
16
18
  #else
17
19
  #include <faiss/utils/fp16-inl.h>
18
20
  #endif