faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -76,6 +76,12 @@ bool getTensorCoreSupport(int device);
76
76
  /// Equivalent to getTensorCoreSupport(getCurrentDevice())
77
77
  bool getTensorCoreSupportCurrentDevice();
78
78
 
79
+ /// Returns the warp size of the given GPU device
80
+ int getWarpSize(int device);
81
+
82
+ /// Equivalent to getWarpSize(getCurrentDevice())
83
+ int getWarpSizeCurrentDevice();
84
+
79
85
  /// Returns the amount of currently available memory on the given device
80
86
  size_t getFreeMemory(int device);
81
87
 
@@ -0,0 +1,75 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+ /*
8
+ * Copyright (c) 2023, NVIDIA CORPORATION.
9
+ *
10
+ * Licensed under the Apache License, Version 2.0 (the "License");
11
+ * you may not use this file except in compliance with the License.
12
+ * You may obtain a copy of the License at
13
+ *
14
+ * http://www.apache.org/licenses/LICENSE-2.0
15
+ *
16
+ * Unless required by applicable law or agreed to in writing, software
17
+ * distributed under the License is distributed on an "AS IS" BASIS,
18
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ * See the License for the specific language governing permissions and
20
+ * limitations under the License.
21
+ */
22
+
23
+ #pragma once
24
+
25
+ #include <faiss/MetricType.h>
26
+ #include <faiss/gpu/GpuResources.h>
27
+ #include <faiss/gpu/utils/Tensor.cuh>
28
+
29
+ #include <raft/distance/distance_types.hpp>
30
+
31
+ #pragma GCC visibility push(default)
32
+ namespace faiss {
33
+ namespace gpu {
34
+
35
+ inline raft::distance::DistanceType metricFaissToRaft(
36
+ MetricType metric,
37
+ bool exactDistance) {
38
+ switch (metric) {
39
+ case MetricType::METRIC_INNER_PRODUCT:
40
+ return raft::distance::DistanceType::InnerProduct;
41
+ case MetricType::METRIC_L2:
42
+ return raft::distance::DistanceType::L2Expanded;
43
+ case MetricType::METRIC_L1:
44
+ return raft::distance::DistanceType::L1;
45
+ case MetricType::METRIC_Linf:
46
+ return raft::distance::DistanceType::Linf;
47
+ case MetricType::METRIC_Lp:
48
+ return raft::distance::DistanceType::LpUnexpanded;
49
+ case MetricType::METRIC_Canberra:
50
+ return raft::distance::DistanceType::Canberra;
51
+ case MetricType::METRIC_BrayCurtis:
52
+ return raft::distance::DistanceType::BrayCurtis;
53
+ case MetricType::METRIC_JensenShannon:
54
+ return raft::distance::DistanceType::JensenShannon;
55
+ default:
56
+ RAFT_FAIL("Distance type not supported");
57
+ }
58
+ }
59
+
60
+ /// Identify matrix rows containing non NaN values. validRows[i] is false if row
61
+ /// i contains a NaN value and true otherwise.
62
+ void validRowIndices(
63
+ GpuResources* res,
64
+ Tensor<float, 2, true>& vecs,
65
+ bool* validRows);
66
+
67
+ /// Filter out matrix rows containing NaN values. The vectors and indices are
68
+ /// updated in-place.
69
+ idx_t inplaceGatherFilteredRows(
70
+ GpuResources* res,
71
+ Tensor<float, 2, true>& vecs,
72
+ Tensor<idx_t, 1, true>& indices);
73
+ } // namespace gpu
74
+ } // namespace faiss
75
+ #pragma GCC visibility pop
@@ -14,7 +14,10 @@ namespace faiss {
14
14
  namespace gpu {
15
15
 
16
16
  KernelTimer::KernelTimer(cudaStream_t stream)
17
- : startEvent_(0), stopEvent_(0), stream_(stream), valid_(true) {
17
+ : startEvent_(nullptr),
18
+ stopEvent_(nullptr),
19
+ stream_(stream),
20
+ valid_(true) {
18
21
  CUDA_VERIFY(cudaEventCreate(&startEvent_));
19
22
  CUDA_VERIFY(cudaEventCreate(&stopEvent_));
20
23
 
@@ -18,7 +18,7 @@ class KernelTimer {
18
18
  public:
19
19
  /// Constructor starts the timer and adds an event into the current
20
20
  /// device stream
21
- KernelTimer(cudaStream_t stream = 0);
21
+ KernelTimer(cudaStream_t stream = nullptr);
22
22
 
23
23
  /// Destructor releases event resources
24
24
  ~KernelTimer();
@@ -261,7 +261,7 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
261
261
  is_trained, "The additive quantizer is not trained yet.");
262
262
 
263
263
  // standard additive quantizer decoding
264
- #pragma omp parallel for if (n > 1000)
264
+ #pragma omp parallel for if (n > 100)
265
265
  for (int64_t i = 0; i < n; i++) {
266
266
  BitstringReader bsr(code + i * code_size, code_size);
267
267
  float* xi = x + i * d;
@@ -370,6 +370,8 @@ void AdditiveQuantizer::compute_LUT(
370
370
 
371
371
  namespace {
372
372
 
373
+ /* compute inner products of one query with all centroids, given a look-up
374
+ * table of all inner producst with codebook entries */
373
375
  void compute_inner_prod_with_LUT(
374
376
  const AdditiveQuantizer& aq,
375
377
  const float* LUT,
@@ -49,11 +49,11 @@ struct AdditiveQuantizer : Quantizer {
49
49
  /// encode a norm into norm_bits bits
50
50
  uint64_t encode_norm(float norm) const;
51
51
 
52
- uint32_t encode_qcint(
53
- float x) const; ///< encode norm by non-uniform scalar quantization
52
+ /// encode norm by non-uniform scalar quantization
53
+ uint32_t encode_qcint(float x) const;
54
54
 
55
- float decode_qcint(uint32_t c)
56
- const; ///< decode norm by non-uniform scalar quantization
55
+ /// decode norm by non-uniform scalar quantization
56
+ float decode_qcint(uint32_t c) const;
57
57
 
58
58
  /// Encodes how search is performed and how vectors are encoded
59
59
  enum Search_type_t {
@@ -203,4 +203,4 @@ struct AdditiveQuantizer : Quantizer {
203
203
  virtual ~AdditiveQuantizer();
204
204
  };
205
205
 
206
- }; // namespace faiss
206
+ } // namespace faiss
@@ -230,10 +230,35 @@ bool InterruptCallback::is_interrupted() {
230
230
 
231
231
  size_t InterruptCallback::get_period_hint(size_t flops) {
232
232
  if (!instance.get()) {
233
- return 1L << 30; // never check
233
+ return (size_t)1 << 30; // never check
234
234
  }
235
235
  // for 10M flops, it is reasonable to check once every 10 iterations
236
236
  return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
237
237
  }
238
238
 
239
+ void TimeoutCallback::set_timeout(double timeout_in_seconds) {
240
+ timeout = timeout_in_seconds;
241
+ start = std::chrono::steady_clock::now();
242
+ }
243
+
244
+ bool TimeoutCallback::want_interrupt() {
245
+ if (timeout == 0) {
246
+ return false;
247
+ }
248
+ auto end = std::chrono::steady_clock::now();
249
+ std::chrono::duration<float, std::milli> duration = end - start;
250
+ float elapsed_in_seconds = duration.count() / 1000.0;
251
+ if (elapsed_in_seconds > timeout) {
252
+ timeout = 0;
253
+ return true;
254
+ }
255
+ return false;
256
+ }
257
+
258
+ void TimeoutCallback::reset(double timeout_in_seconds) {
259
+ auto tc(new faiss::TimeoutCallback());
260
+ faiss::InterruptCallback::instance.reset(tc);
261
+ tc->set_timeout(timeout_in_seconds);
262
+ }
263
+
239
264
  } // namespace faiss
@@ -41,7 +41,6 @@ struct RangeSearchResult {
41
41
 
42
42
  /// called when lims contains the nb of elements result entries
43
43
  /// for each query
44
-
45
44
  virtual void do_allocation();
46
45
 
47
46
  virtual ~RangeSearchResult();
@@ -123,7 +122,7 @@ struct RangeSearchPartialResult : BufferList {
123
122
  void copy_result(bool incremental = false);
124
123
 
125
124
  /// merge a set of PartialResult's into one RangeSearchResult
126
- /// on ouptut the partialresults are empty!
125
+ /// on output the partialresults are empty!
127
126
  static void merge(
128
127
  std::vector<RangeSearchPartialResult*>& partial_results,
129
128
  bool do_delete = true);
@@ -162,10 +161,18 @@ struct FAISS_API InterruptCallback {
162
161
  static size_t get_period_hint(size_t flops);
163
162
  };
164
163
 
164
+ struct TimeoutCallback : InterruptCallback {
165
+ std::chrono::time_point<std::chrono::steady_clock> start;
166
+ double timeout;
167
+ bool want_interrupt() override;
168
+ void set_timeout(double timeout_in_seconds);
169
+ static void reset(double timeout_in_seconds);
170
+ };
171
+
165
172
  /// set implementation optimized for fast access.
166
173
  struct VisitedTable {
167
174
  std::vector<uint8_t> visited;
168
- int visno;
175
+ uint8_t visno;
169
176
 
170
177
  explicit VisitedTable(int size) : visited(size), visno(1) {}
171
178
 
@@ -30,12 +30,81 @@ struct DistanceComputer {
30
30
  /// compute distance of vector i to current query
31
31
  virtual float operator()(idx_t i) = 0;
32
32
 
33
+ /// compute distances of current query to 4 stored vectors.
34
+ /// certain DistanceComputer implementations may benefit
35
+ /// heavily from this.
36
+ virtual void distances_batch_4(
37
+ const idx_t idx0,
38
+ const idx_t idx1,
39
+ const idx_t idx2,
40
+ const idx_t idx3,
41
+ float& dis0,
42
+ float& dis1,
43
+ float& dis2,
44
+ float& dis3) {
45
+ // compute first, assign next
46
+ const float d0 = this->operator()(idx0);
47
+ const float d1 = this->operator()(idx1);
48
+ const float d2 = this->operator()(idx2);
49
+ const float d3 = this->operator()(idx3);
50
+ dis0 = d0;
51
+ dis1 = d1;
52
+ dis2 = d2;
53
+ dis3 = d3;
54
+ }
55
+
33
56
  /// compute distance between two stored vectors
34
57
  virtual float symmetric_dis(idx_t i, idx_t j) = 0;
35
58
 
36
59
  virtual ~DistanceComputer() {}
37
60
  };
38
61
 
62
+ /* Wrap the distance computer into one that negates the
63
+ distances. This makes supporting INNER_PRODUCE search easier */
64
+
65
+ struct NegativeDistanceComputer : DistanceComputer {
66
+ /// owned by this
67
+ DistanceComputer* basedis;
68
+
69
+ explicit NegativeDistanceComputer(DistanceComputer* basedis)
70
+ : basedis(basedis) {}
71
+
72
+ void set_query(const float* x) override {
73
+ basedis->set_query(x);
74
+ }
75
+
76
+ /// compute distance of vector i to current query
77
+ float operator()(idx_t i) override {
78
+ return -(*basedis)(i);
79
+ }
80
+
81
+ void distances_batch_4(
82
+ const idx_t idx0,
83
+ const idx_t idx1,
84
+ const idx_t idx2,
85
+ const idx_t idx3,
86
+ float& dis0,
87
+ float& dis1,
88
+ float& dis2,
89
+ float& dis3) override {
90
+ basedis->distances_batch_4(
91
+ idx0, idx1, idx2, idx3, dis0, dis1, dis2, dis3);
92
+ dis0 = -dis0;
93
+ dis1 = -dis1;
94
+ dis2 = -dis2;
95
+ dis3 = -dis3;
96
+ }
97
+
98
+ /// compute distance between two stored vectors
99
+ float symmetric_dis(idx_t i, idx_t j) override {
100
+ return -basedis->symmetric_dis(i, j);
101
+ }
102
+
103
+ virtual ~NegativeDistanceComputer() {
104
+ delete basedis;
105
+ }
106
+ };
107
+
39
108
  /*************************************************************
40
109
  * Specialized version of the DistanceComputer when we know that codes are
41
110
  * laid out in a flat index.
@@ -49,7 +118,7 @@ struct FlatCodesDistanceComputer : DistanceComputer {
49
118
 
50
119
  FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
51
120
 
52
- float operator()(idx_t i) final {
121
+ float operator()(idx_t i) override {
53
122
  return distance_to_code(codes + i * code_size);
54
123
  }
55
124
 
@@ -94,13 +94,15 @@
94
94
  } \
95
95
  } while (false)
96
96
 
97
- #define FAISS_THROW_IF_NOT_MSG(X, MSG) \
97
+ #define FAISS_THROW_IF_MSG(X, MSG) \
98
98
  do { \
99
- if (!(X)) { \
99
+ if (X) { \
100
100
  FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X); \
101
101
  } \
102
102
  } while (false)
103
103
 
104
+ #define FAISS_THROW_IF_NOT_MSG(X, MSG) FAISS_THROW_IF_MSG(!(X), MSG)
105
+
104
106
  #define FAISS_THROW_IF_NOT_FMT(X, FMT, ...) \
105
107
  do { \
106
108
  if (!(X)) { \
@@ -1,3 +1,4 @@
1
+
1
2
  /**
2
3
  * Copyright (c) Facebook, Inc. and its affiliates.
3
4
  *
@@ -40,42 +41,20 @@ class FaissException : public std::exception {
40
41
  void handleExceptions(
41
42
  std::vector<std::pair<int, std::exception_ptr>>& exceptions);
42
43
 
43
- /** bare-bones unique_ptr
44
- * this one deletes with delete [] */
45
- template <class T>
46
- struct ScopeDeleter {
47
- const T* ptr;
48
- explicit ScopeDeleter(const T* ptr = nullptr) : ptr(ptr) {}
49
- void release() {
50
- ptr = nullptr;
51
- }
52
- void set(const T* ptr_in) {
53
- ptr = ptr_in;
54
- }
55
- void swap(ScopeDeleter<T>& other) {
56
- std::swap(ptr, other.ptr);
57
- }
58
- ~ScopeDeleter() {
59
- delete[] ptr;
44
+ /** RAII object for a set of possibly transformed vectors (deallocated only if
45
+ * they are indeed transformed)
46
+ */
47
+ struct TransformedVectors {
48
+ const float* x;
49
+ bool own_x;
50
+ TransformedVectors(const float* x_orig, const float* x) : x(x) {
51
+ own_x = x_orig != x;
60
52
  }
61
- };
62
53
 
63
- /** same but deletes with the simple delete (least common case) */
64
- template <class T>
65
- struct ScopeDeleter1 {
66
- const T* ptr;
67
- explicit ScopeDeleter1(const T* ptr = nullptr) : ptr(ptr) {}
68
- void release() {
69
- ptr = nullptr;
70
- }
71
- void set(const T* ptr_in) {
72
- ptr = ptr_in;
73
- }
74
- void swap(ScopeDeleter1<T>& other) {
75
- std::swap(ptr, other.ptr);
76
- }
77
- ~ScopeDeleter1() {
78
- delete ptr;
54
+ ~TransformedVectors() {
55
+ if (own_x) {
56
+ delete[] x;
57
+ }
79
58
  }
80
59
  };
81
60