faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -5,14 +5,12 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexIVF.h>
11
9
 
12
10
  #include <omp.h>
11
+ #include <atomic>
13
12
  #include <cstdint>
14
13
  #include <memory>
15
- #include <mutex>
16
14
 
17
15
  #include <algorithm>
18
16
  #include <cinttypes>
@@ -27,6 +25,8 @@
27
25
  #include <faiss/impl/CodePacker.h>
28
26
  #include <faiss/impl/FaissAssert.h>
29
27
  #include <faiss/impl/IDSelector.h>
28
+ #include <faiss/impl/ResultHandler.h>
29
+ #include <faiss/impl/expanded_scanners.h>
30
30
 
31
31
  namespace faiss {
32
32
 
@@ -37,8 +37,8 @@ using ScopedCodes = InvertedLists::ScopedCodes;
37
37
  * Level1Quantizer implementation
38
38
  ******************************************/
39
39
 
40
- Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
41
- : quantizer(quantizer), nlist(nlist) {
40
+ Level1Quantizer::Level1Quantizer(Index* quantizer_in, size_t nlist_in)
41
+ : quantizer(quantizer_in), nlist(nlist_in) {
42
42
  // here we set a low # iterations because this is typically used
43
43
  // for large clusterings (nb this is not used for the MultiIndex,
44
44
  // for which quantizer_trains_alone = true)
@@ -58,8 +58,10 @@ void Level1Quantizer::train_q1(
58
58
  const float* x,
59
59
  bool verbose,
60
60
  MetricType metric_type) {
61
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
61
62
  size_t d = quantizer->d;
62
- if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
63
+ if (quantizer->is_trained &&
64
+ (static_cast<size_t>(quantizer->ntotal) == nlist)) {
63
65
  if (verbose) {
64
66
  printf("IVF quantizer does not need training.\n");
65
67
  }
@@ -70,14 +72,14 @@ void Level1Quantizer::train_q1(
70
72
  quantizer->verbose = verbose;
71
73
  quantizer->train(n, x);
72
74
  FAISS_THROW_IF_NOT_MSG(
73
- quantizer->ntotal == nlist,
75
+ static_cast<size_t>(quantizer->ntotal) == nlist,
74
76
  "nlist not consistent with quantizer size");
75
77
  } else if (quantizer_trains_alone == 0) {
76
78
  if (verbose) {
77
79
  printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
78
80
  }
79
81
 
80
- Clustering clus(d, nlist, cp);
82
+ Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
81
83
  quantizer->reset();
82
84
  if (clustering_index) {
83
85
  clus.train(n, x, *clustering_index);
@@ -99,7 +101,7 @@ void Level1Quantizer::train_q1(
99
101
  metric_type == METRIC_L2 ||
100
102
  (metric_type == METRIC_INNER_PRODUCT && cp.spherical));
101
103
 
102
- Clustering clus(d, nlist, cp);
104
+ Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
103
105
  if (!clustering_index) {
104
106
  IndexFlatL2 assigner(d);
105
107
  clus.train(n, x, assigner);
@@ -148,7 +150,7 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
148
150
  nbit += 8;
149
151
  nl >>= 8;
150
152
  }
151
- FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
153
+ FAISS_THROW_IF_NOT(list_no >= 0 && static_cast<size_t>(list_no) < nlist);
152
154
  return list_no;
153
155
  }
154
156
 
@@ -157,21 +159,23 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
157
159
  ******************************************/
158
160
 
159
161
  IndexIVF::IndexIVF(
160
- Index* quantizer,
161
- size_t d,
162
- size_t nlist,
163
- size_t code_size,
162
+ Index* quantizer_in,
163
+ size_t d_in,
164
+ size_t nlist_in,
165
+ size_t code_size_in,
164
166
  MetricType metric,
165
- bool own_invlists)
166
- : Index(d, metric),
167
- IndexIVFInterface(quantizer, nlist),
167
+ bool own_invlists_in)
168
+ : Index(d_in, metric),
169
+ IndexIVFInterface(quantizer_in, nlist_in),
168
170
  invlists(
169
- own_invlists ? new ArrayInvertedLists(nlist, code_size)
170
- : nullptr),
171
- own_invlists(own_invlists),
172
- code_size(code_size) {
173
- FAISS_THROW_IF_NOT(d == quantizer->d);
174
- is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
171
+ own_invlists_in
172
+ ? new ArrayInvertedLists(nlist_in, code_size_in)
173
+ : nullptr),
174
+ own_invlists(own_invlists_in),
175
+ code_size(code_size_in) {
176
+ FAISS_THROW_IF_NOT(static_cast<int>(d_in) == quantizer_in->d);
177
+ is_trained = quantizer_in->is_trained &&
178
+ (static_cast<size_t>(quantizer_in->ntotal) == nlist_in);
175
179
  // Spherical by default if the metric is inner_product
176
180
  if (metric_type == METRIC_INNER_PRODUCT) {
177
181
  cp.spherical = true;
@@ -185,6 +189,8 @@ void IndexIVF::add(idx_t n, const float* x) {
185
189
  }
186
190
 
187
191
  void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
192
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
193
+ FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
188
194
  std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
189
195
  quantizer->assign(n, x, coarse_idx.get());
190
196
  add_core(n, x, xids, coarse_idx.get());
@@ -235,7 +241,7 @@ void IndexIVF::add_core(
235
241
 
236
242
  size_t nadd = 0, nminus1 = 0;
237
243
 
238
- for (size_t i = 0; i < n; i++) {
244
+ for (idx_t i = 0; i < n; i++) {
239
245
  if (coarse_idx[i] < 0) {
240
246
  nminus1++;
241
247
  }
@@ -252,7 +258,7 @@ void IndexIVF::add_core(
252
258
  int rank = omp_get_thread_num();
253
259
 
254
260
  // each thread takes care of a subset of lists
255
- for (size_t i = 0; i < n; i++) {
261
+ for (idx_t i = 0; i < n; i++) {
256
262
  idx_t list_no = coarse_idx[i];
257
263
  if (list_no >= 0 && list_no % nt == rank) {
258
264
  idx_t id = xids ? xids[i] : ntotal + i;
@@ -305,45 +311,49 @@ void IndexIVF::search(
305
311
  idx_t* labels,
306
312
  const SearchParameters* params_in) const {
307
313
  FAISS_THROW_IF_NOT(k > 0);
314
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
315
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
316
+ FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
308
317
  const IVFSearchParameters* params = nullptr;
309
318
  if (params_in) {
310
319
  params = dynamic_cast<const IVFSearchParameters*>(params_in);
311
320
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
312
321
  }
313
- const size_t nprobe =
322
+ const size_t cur_nprobe =
314
323
  std::min(nlist, params ? params->nprobe : this->nprobe);
315
- FAISS_THROW_IF_NOT(nprobe > 0);
324
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
316
325
 
317
326
  // search function for a subset of queries
318
- auto sub_search_func = [this, k, nprobe, params](
319
- idx_t n,
320
- const float* x,
321
- float* distances,
322
- idx_t* labels,
327
+ auto sub_search_func = [this, k, cur_nprobe, params](
328
+ idx_t sub_n,
329
+ const float* sub_x,
330
+ float* sub_distances,
331
+ idx_t* sub_labels,
323
332
  IndexIVFStats* ivf_stats) {
324
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
325
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
333
+ std::unique_ptr<idx_t[]> idx(new idx_t[sub_n * cur_nprobe]);
334
+ std::unique_ptr<float[]> coarse_dis(new float[sub_n * cur_nprobe]);
326
335
 
327
336
  double t0 = getmillisecs();
328
337
  quantizer->search(
329
- n,
330
- x,
331
- nprobe,
338
+ sub_n,
339
+ sub_x,
340
+ cur_nprobe,
332
341
  coarse_dis.get(),
333
342
  idx.get(),
334
343
  params ? params->quantizer_params : nullptr);
335
344
 
336
345
  double t1 = getmillisecs();
337
- invlists->prefetch_lists(idx.get(), n * nprobe);
346
+ invlists->prefetch_lists(
347
+ idx.get(), static_cast<int>(sub_n * cur_nprobe));
338
348
 
339
349
  search_preassigned(
340
- n,
341
- x,
350
+ sub_n,
351
+ sub_x,
342
352
  k,
343
353
  idx.get(),
344
354
  coarse_dis.get(),
345
- distances,
346
- labels,
355
+ sub_distances,
356
+ sub_labels,
347
357
  false,
348
358
  params,
349
359
  ivf_stats);
@@ -355,32 +365,28 @@ void IndexIVF::search(
355
365
  if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
356
366
  int nt = std::min(omp_get_max_threads(), int(n));
357
367
  std::vector<IndexIVFStats> stats(nt);
358
- std::mutex exception_mutex;
359
- std::string exception_string;
368
+ std::exception_ptr ex;
360
369
 
361
370
  #pragma omp parallel for if (nt > 1)
362
371
  for (idx_t slice = 0; slice < nt; slice++) {
363
- IndexIVFStats local_stats;
364
- idx_t i0 = n * slice / nt;
365
- idx_t i1 = n * (slice + 1) / nt;
366
- if (i1 > i0) {
367
- try {
372
+ try {
373
+ IndexIVFStats local_stats;
374
+ idx_t i0 = n * slice / nt;
375
+ idx_t i1 = n * (slice + 1) / nt;
376
+ if (i1 > i0) {
368
377
  sub_search_func(
369
378
  i1 - i0,
370
379
  x + i0 * d,
371
380
  distances + i0 * k,
372
381
  labels + i0 * k,
373
382
  &stats[slice]);
374
- } catch (const std::exception& e) {
375
- std::lock_guard<std::mutex> lock(exception_mutex);
376
- exception_string = e.what();
377
383
  }
384
+ } catch (...) {
385
+ omp_capture_exception(ex);
378
386
  }
379
387
  }
380
388
 
381
- if (!exception_string.empty()) {
382
- FAISS_THROW_MSG(exception_string.c_str());
383
- }
389
+ omp_rethrow_if_exception(ex);
384
390
 
385
391
  // collect stats
386
392
  for (idx_t slice = 0; slice < nt; slice++) {
@@ -405,13 +411,17 @@ void IndexIVF::search_preassigned(
405
411
  const IVFSearchParameters* params,
406
412
  IndexIVFStats* ivf_stats) const {
407
413
  FAISS_THROW_IF_NOT(k > 0);
414
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
415
+ FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
408
416
 
409
- idx_t nprobe = params ? params->nprobe : this->nprobe;
410
- nprobe = std::min((idx_t)nlist, nprobe);
411
- FAISS_THROW_IF_NOT(nprobe > 0);
417
+ idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
418
+ cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
419
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
412
420
 
413
421
  const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
414
- idx_t max_codes = params ? params->max_codes : this->max_codes;
422
+ idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
423
+ const bool ensure_topk_full = params ? params->ensure_topk_full : false;
424
+
415
425
  IDSelector* sel = params ? params->sel : nullptr;
416
426
  const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
417
427
  if (selr) {
@@ -427,7 +437,8 @@ void IndexIVF::search_preassigned(
427
437
  "selector and store_pairs cannot be combined");
428
438
 
429
439
  FAISS_THROW_IF_NOT_MSG(
430
- !invlists->use_iterator || (max_codes == 0 && store_pairs == false),
440
+ !invlists->use_iterator ||
441
+ (cur_max_codes == 0 && store_pairs == false),
431
442
  "iterable inverted lists don't support max_codes and store_pairs");
432
443
 
433
444
  size_t nlistv = 0, ndis = 0, nheap = 0;
@@ -435,106 +446,119 @@ void IndexIVF::search_preassigned(
435
446
  using HeapForIP = CMin<float, idx_t>;
436
447
  using HeapForL2 = CMax<float, idx_t>;
437
448
 
438
- bool interrupt = false;
439
- std::mutex exception_mutex;
440
- std::string exception_string;
449
+ std::exception_ptr ex;
450
+ std::atomic<bool> interrupt{false};
441
451
 
442
452
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
443
453
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
444
454
 
445
455
  FAISS_THROW_IF_NOT_MSG(
446
- max_codes == 0 || pmode == 0 || pmode == 3,
456
+ cur_max_codes == 0 || pmode == 0 || pmode == 3,
447
457
  "max_codes supported only for parallel_mode = 0 or 3");
448
458
 
449
- if (max_codes == 0) {
450
- max_codes = unlimited_list_size;
459
+ FAISS_THROW_IF_NOT_MSG(
460
+ !ensure_topk_full || pmode == 0 || pmode == 3,
461
+ "ensure_topk_full supported only for parallel_mode = 0 or 3");
462
+
463
+ if (cur_max_codes == 0) {
464
+ cur_max_codes = unlimited_list_size;
451
465
  }
466
+ // Budget used by the probe loop below. ensure_topk_full makes a small
467
+ // max_codes budget large enough to give k post-filter candidates a chance.
468
+ idx_t effective_max_codes =
469
+ ensure_topk_full ? std::max(cur_max_codes, k) : cur_max_codes;
452
470
 
453
471
  [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
454
472
  (pmode == 0 ? false
455
473
  : pmode == 3 ? n > 1
456
- : pmode == 1 ? nprobe > 1
457
- : nprobe * n > 1);
474
+ : pmode == 1 ? cur_nprobe > 1
475
+ : cur_nprobe * n > 1);
458
476
 
459
477
  void* inverted_list_context =
460
478
  params ? params->inverted_list_context : nullptr;
461
479
 
462
480
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
463
481
  {
464
- std::unique_ptr<InvertedListScanner> scanner(
465
- get_InvertedListScanner(store_pairs, sel, params));
466
-
467
- /*****************************************************
468
- * Depending on parallel_mode, there are two possible ways
469
- * to organize the search. Here we define local functions
470
- * that are in common between the two
471
- ******************************************************/
472
-
473
- // initialize + reorder a result heap
482
+ // C++ exceptions that escape an OpenMP parallel region without being
483
+ // caught inside it call std::terminate — they cannot propagate across
484
+ // thread boundaries. The outer try/catch covers per-thread setup
485
+ // (scanner creation, set_query); the inner try/catch in scan_one_list
486
+ // covers per-list operations. Both set interrupt=true to stop further
487
+ // work and re-throw after the parallel region exits.
488
+ try {
489
+ std::unique_ptr<InvertedListScanner> scanner(
490
+ get_InvertedListScanner(store_pairs, sel, params));
491
+
492
+ /*****************************************************
493
+ * Depending on parallel_mode, there are two possible ways
494
+ * to organize the search. Here we define local functions
495
+ * that are in common between the two
496
+ ******************************************************/
497
+
498
+ // initialize + reorder a result heap
499
+
500
+ auto init_result = [&](float* simi, idx_t* idxi) {
501
+ if (!do_heap_init) {
502
+ return;
503
+ }
504
+ if (metric_type == METRIC_INNER_PRODUCT) {
505
+ heap_heapify<HeapForIP>(k, simi, idxi);
506
+ } else {
507
+ heap_heapify<HeapForL2>(k, simi, idxi);
508
+ }
509
+ };
510
+
511
+ auto add_local_results = [&](const float* local_dis,
512
+ const idx_t* local_idx,
513
+ float* simi,
514
+ idx_t* idxi) {
515
+ if (metric_type == METRIC_INNER_PRODUCT) {
516
+ heap_addn<HeapForIP>(
517
+ k, simi, idxi, local_dis, local_idx, k);
518
+ } else {
519
+ heap_addn<HeapForL2>(
520
+ k, simi, idxi, local_dis, local_idx, k);
521
+ }
522
+ };
474
523
 
475
- auto init_result = [&](float* simi, idx_t* idxi) {
476
- if (!do_heap_init) {
477
- return;
478
- }
479
- if (metric_type == METRIC_INNER_PRODUCT) {
480
- heap_heapify<HeapForIP>(k, simi, idxi);
481
- } else {
482
- heap_heapify<HeapForL2>(k, simi, idxi);
483
- }
484
- };
524
+ auto reorder_result = [&](float* simi, idx_t* idxi) {
525
+ if (!do_heap_init) {
526
+ return;
527
+ }
528
+ if (metric_type == METRIC_INNER_PRODUCT) {
529
+ heap_reorder<HeapForIP>(k, simi, idxi);
530
+ } else {
531
+ heap_reorder<HeapForL2>(k, simi, idxi);
532
+ }
533
+ };
485
534
 
486
- auto add_local_results = [&](const float* local_dis,
487
- const idx_t* local_idx,
535
+ // single list scan using the current scanner (with query
536
+ // set properly) and storing results in simi and idxi
537
+ auto scan_one_list = [&](idx_t key,
538
+ float coarse_dis_i,
488
539
  float* simi,
489
- idx_t* idxi) {
490
- if (metric_type == METRIC_INNER_PRODUCT) {
491
- heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
492
- } else {
493
- heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
494
- }
495
- };
496
-
497
- auto reorder_result = [&](float* simi, idx_t* idxi) {
498
- if (!do_heap_init) {
499
- return;
500
- }
501
- if (metric_type == METRIC_INNER_PRODUCT) {
502
- heap_reorder<HeapForIP>(k, simi, idxi);
503
- } else {
504
- heap_reorder<HeapForL2>(k, simi, idxi);
505
- }
506
- };
507
-
508
- // single list scan using the current scanner (with query
509
- // set properly) and storing results in simi and idxi
510
- auto scan_one_list = [&](idx_t key,
511
- float coarse_dis_i,
512
- float* simi,
513
- idx_t* idxi,
514
- idx_t list_size_max) {
515
- if (key < 0) {
516
- // not enough centroids for multiprobe
517
- return (size_t)0;
518
- }
519
- FAISS_THROW_IF_NOT_FMT(
520
- key < (idx_t)nlist,
521
- "Invalid key=%" PRId64 " nlist=%zd\n",
522
- key,
523
- nlist);
524
-
525
- // don't waste time on empty lists
526
- if (invlists->is_empty(key, inverted_list_context)) {
527
- return (size_t)0;
528
- }
529
-
530
- scanner->set_list(key, coarse_dis_i);
540
+ idx_t* idxi,
541
+ idx_t list_size_max) {
542
+ if (key < 0) {
543
+ // not enough centroids for multiprobe
544
+ return (size_t)0;
545
+ }
546
+ FAISS_THROW_IF_NOT_FMT(
547
+ key < (idx_t)nlist,
548
+ "Invalid key=%" PRId64 " nlist=%zd\n",
549
+ key,
550
+ nlist);
551
+
552
+ // don't waste time on empty lists
553
+ if (invlists->is_empty(key, inverted_list_context)) {
554
+ return (size_t)0;
555
+ }
531
556
 
532
- nlistv++;
557
+ scanner->set_list(key, coarse_dis_i);
533
558
 
534
- try {
559
+ nlistv++;
535
560
  if (invlists->use_iterator) {
536
561
  size_t list_size = 0;
537
-
538
562
  std::unique_ptr<InvertedListsIterator> it(
539
563
  invlists->get_iterator(key, inverted_list_context));
540
564
 
@@ -544,8 +568,8 @@ void IndexIVF::search_preassigned(
544
568
  return list_size;
545
569
  } else {
546
570
  size_t list_size = invlists->list_size(key);
547
- if (list_size > list_size_max) {
548
- list_size = list_size_max;
571
+ if (list_size > static_cast<size_t>(list_size_max)) {
572
+ list_size = static_cast<size_t>(list_size_max);
549
573
  }
550
574
 
551
575
  InvertedLists::ScopedCodes scodes(invlists, key);
@@ -573,144 +597,167 @@ void IndexIVF::search_preassigned(
573
597
  ids += jmin;
574
598
  }
575
599
 
576
- nheap += scanner->scan_codes(
577
- list_size, codes, ids, simi, idxi, k);
578
-
579
- return list_size;
600
+ size_t old_scan_cnt = 0;
601
+ size_t old_heap_updates = 0;
602
+ if (metric_type == METRIC_INNER_PRODUCT) {
603
+ HeapResultHandler<HeapForIP, false> handler(
604
+ k, simi, idxi);
605
+ old_scan_cnt = handler.stats.scan_cnt;
606
+ old_heap_updates = handler.stats.nheap_updates;
607
+ scanner->scan_codes(list_size, codes, ids, handler);
608
+ nheap += handler.stats.nheap_updates - old_heap_updates;
609
+ return handler.stats.scan_cnt - old_scan_cnt;
610
+ } else {
611
+ HeapResultHandler<HeapForL2, false> handler(
612
+ k, simi, idxi);
613
+ old_scan_cnt = handler.stats.scan_cnt;
614
+ old_heap_updates = handler.stats.nheap_updates;
615
+ scanner->scan_codes(list_size, codes, ids, handler);
616
+ nheap += handler.stats.nheap_updates - old_heap_updates;
617
+ return handler.stats.scan_cnt - old_scan_cnt;
618
+ }
580
619
  }
581
- } catch (const std::exception& e) {
582
- std::lock_guard<std::mutex> lock(exception_mutex);
583
- exception_string =
584
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
585
- interrupt = true;
586
- return size_t(0);
587
- }
588
- };
620
+ };
589
621
 
590
- /****************************************************
591
- * Actual loops, depending on parallel_mode
592
- ****************************************************/
622
+ /****************************************************
623
+ * Actual loops, depending on parallel_mode
624
+ ****************************************************/
593
625
 
594
- if (pmode == 0 || pmode == 3) {
626
+ if (pmode == 0 || pmode == 3) {
595
627
  #pragma omp for
596
- for (idx_t i = 0; i < n; i++) {
597
- if (interrupt) {
598
- continue;
599
- }
600
-
601
- // loop over queries
602
- scanner->set_query(x + i * d);
603
- float* simi = distances + i * k;
604
- idx_t* idxi = labels + i * k;
605
-
606
- init_result(simi, idxi);
607
-
608
- idx_t nscan = 0;
609
-
610
- // loop over probes
611
- for (size_t ik = 0; ik < nprobe; ik++) {
612
- nscan += scan_one_list(
613
- keys[i * nprobe + ik],
614
- coarse_dis[i * nprobe + ik],
615
- simi,
616
- idxi,
617
- max_codes - nscan);
618
- if (nscan >= max_codes) {
619
- break;
628
+ for (idx_t i = 0; i < n; i++) {
629
+ if (interrupt.load(std::memory_order_relaxed)) {
630
+ continue;
620
631
  }
621
- }
632
+ try {
633
+ // loop over queries
634
+ scanner->set_query(x + i * d);
635
+ float* simi = distances + i * k;
636
+ idx_t* idxi = labels + i * k;
637
+
638
+ init_result(simi, idxi);
639
+
640
+ idx_t nscan = 0;
641
+
642
+ // loop over probes
643
+ for (idx_t ik = 0; ik < cur_nprobe; ik++) {
644
+ // For soft budgets, scan whole lists so
645
+ // IDSelector-filtered rows do not consume the
646
+ // remaining code budget.
647
+ const idx_t list_size_max = ensure_topk_full
648
+ ? unlimited_list_size
649
+ : effective_max_codes - nscan;
650
+ nscan += scan_one_list(
651
+ keys[i * cur_nprobe + ik],
652
+ coarse_dis[i * cur_nprobe + ik],
653
+ simi,
654
+ idxi,
655
+ list_size_max);
656
+
657
+ // Early-stop check: apply max_codes after each
658
+ // list. nscan is the number of distances
659
+ // actually computed.
660
+ if (nscan >= effective_max_codes) {
661
+ break;
662
+ }
663
+ }
622
664
 
623
- ndis += nscan;
624
- reorder_result(simi, idxi);
665
+ ndis += nscan;
666
+ reorder_result(simi, idxi);
625
667
 
626
- if (InterruptCallback::is_interrupted()) {
627
- interrupt = true;
628
- }
629
-
630
- } // parallel for
631
- } else if (pmode == 1) {
632
- std::vector<idx_t> local_idx(k);
633
- std::vector<float> local_dis(k);
668
+ InterruptCallback::check();
669
+ } catch (...) {
670
+ omp_capture_exception(ex, [&] { interrupt = true; });
671
+ }
672
+ } // parallel for
673
+ } else if (pmode == 1) {
674
+ std::vector<idx_t> local_idx(k);
675
+ std::vector<float> local_dis(k);
634
676
 
635
- for (size_t i = 0; i < n; i++) {
636
- scanner->set_query(x + i * d);
637
- init_result(local_dis.data(), local_idx.data());
677
+ for (idx_t i = 0; i < n; i++) {
678
+ scanner->set_query(x + i * d);
679
+ init_result(local_dis.data(), local_idx.data());
638
680
 
639
681
  #pragma omp for schedule(dynamic)
640
- for (idx_t ik = 0; ik < nprobe; ik++) {
641
- ndis += scan_one_list(
642
- keys[i * nprobe + ik],
643
- coarse_dis[i * nprobe + ik],
644
- local_dis.data(),
645
- local_idx.data(),
646
- unlimited_list_size);
647
-
648
- // can't do the test on max_codes
649
- }
650
- // merge thread-local results
682
+ for (idx_t ik = 0; ik < cur_nprobe; ik++) {
683
+ try {
684
+ ndis += scan_one_list(
685
+ keys[i * cur_nprobe + ik],
686
+ coarse_dis[i * cur_nprobe + ik],
687
+ local_dis.data(),
688
+ local_idx.data(),
689
+ unlimited_list_size);
690
+
691
+ // can't do the test on max_codes
692
+ } catch (...) {
693
+ omp_capture_exception(
694
+ ex, [&] { interrupt = true; });
695
+ }
696
+ }
697
+ // merge thread-local results
651
698
 
652
- float* simi = distances + i * k;
653
- idx_t* idxi = labels + i * k;
699
+ float* simi = distances + i * k;
700
+ idx_t* idxi = labels + i * k;
654
701
  #pragma omp single
655
- init_result(simi, idxi);
702
+ init_result(simi, idxi);
656
703
 
657
704
  #pragma omp barrier
658
705
  #pragma omp critical
659
- {
660
- add_local_results(
661
- local_dis.data(), local_idx.data(), simi, idxi);
662
- }
706
+ {
707
+ add_local_results(
708
+ local_dis.data(), local_idx.data(), simi, idxi);
709
+ }
663
710
  #pragma omp barrier
664
711
  #pragma omp single
665
- reorder_result(simi, idxi);
666
- }
667
- } else if (pmode == 2) {
668
- std::vector<idx_t> local_idx(k);
669
- std::vector<float> local_dis(k);
712
+ reorder_result(simi, idxi);
713
+ }
714
+ } else if (pmode == 2) {
715
+ std::vector<idx_t> local_idx(k);
716
+ std::vector<float> local_dis(k);
670
717
 
671
718
  #pragma omp single
672
- for (int64_t i = 0; i < n; i++) {
673
- init_result(distances + i * k, labels + i * k);
674
- }
719
+ for (int64_t i = 0; i < n; i++) {
720
+ init_result(distances + i * k, labels + i * k);
721
+ }
675
722
 
676
723
  #pragma omp for schedule(dynamic)
677
- for (int64_t ij = 0; ij < n * nprobe; ij++) {
678
- size_t i = ij / nprobe;
679
-
680
- scanner->set_query(x + i * d);
681
- init_result(local_dis.data(), local_idx.data());
682
- ndis += scan_one_list(
683
- keys[ij],
684
- coarse_dis[ij],
685
- local_dis.data(),
686
- local_idx.data(),
687
- unlimited_list_size);
724
+ for (int64_t ij = 0; ij < n * cur_nprobe; ij++) {
725
+ try {
726
+ size_t i = ij / cur_nprobe;
727
+
728
+ scanner->set_query(x + i * d);
729
+ init_result(local_dis.data(), local_idx.data());
730
+ ndis += scan_one_list(
731
+ keys[ij],
732
+ coarse_dis[ij],
733
+ local_dis.data(),
734
+ local_idx.data(),
735
+ unlimited_list_size);
688
736
  #pragma omp critical
689
- {
690
- add_local_results(
691
- local_dis.data(),
692
- local_idx.data(),
693
- distances + i * k,
694
- labels + i * k);
737
+ {
738
+ add_local_results(
739
+ local_dis.data(),
740
+ local_idx.data(),
741
+ distances + i * k,
742
+ labels + i * k);
743
+ }
744
+ } catch (...) {
745
+ omp_capture_exception(ex, [&] { interrupt = true; });
746
+ }
695
747
  }
696
- }
697
748
  #pragma omp single
698
- for (int64_t i = 0; i < n; i++) {
699
- reorder_result(distances + i * k, labels + i * k);
749
+ for (int64_t i = 0; i < n; i++) {
750
+ reorder_result(distances + i * k, labels + i * k);
751
+ }
752
+ } else {
753
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
700
754
  }
701
- } else {
702
- FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
755
+ } catch (...) {
756
+ omp_capture_exception(ex, [&] { interrupt = true; });
703
757
  }
704
758
  } // parallel section
705
759
 
706
- if (interrupt) {
707
- if (!exception_string.empty()) {
708
- FAISS_THROW_FMT(
709
- "search interrupted with: %s", exception_string.c_str());
710
- } else {
711
- FAISS_THROW_MSG("computation interrupted");
712
- }
713
- }
760
+ omp_rethrow_if_exception(ex);
714
761
 
715
762
  if (ivf_stats == nullptr) {
716
763
  ivf_stats = &indexIVF_stats;
@@ -727,6 +774,8 @@ void IndexIVF::range_search(
727
774
  float radius,
728
775
  RangeSearchResult* result,
729
776
  const SearchParameters* params_in) const {
777
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
778
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
730
779
  const IVFSearchParameters* params = nullptr;
731
780
  const SearchParameters* quantizer_params = nullptr;
732
781
  if (params_in) {
@@ -734,18 +783,18 @@ void IndexIVF::range_search(
734
783
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
735
784
  quantizer_params = params->quantizer_params;
736
785
  }
737
- const size_t nprobe =
786
+ const size_t cur_nprobe =
738
787
  std::min(nlist, params ? params->nprobe : this->nprobe);
739
- std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
740
- std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
788
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
789
+ std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
741
790
 
742
791
  double t0 = getmillisecs();
743
792
  quantizer->search(
744
- nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
793
+ nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
745
794
  indexIVF_stats.quantization_time += getmillisecs() - t0;
746
795
 
747
796
  t0 = getmillisecs();
748
- invlists->prefetch_lists(keys.get(), nx * nprobe);
797
+ invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
749
798
 
750
799
  range_search_preassigned(
751
800
  nx,
@@ -771,22 +820,29 @@ void IndexIVF::range_search_preassigned(
771
820
  bool store_pairs,
772
821
  const IVFSearchParameters* params,
773
822
  IndexIVFStats* stats) const {
774
- idx_t nprobe = params ? params->nprobe : this->nprobe;
775
- nprobe = std::min((idx_t)nlist, nprobe);
776
- FAISS_THROW_IF_NOT(nprobe > 0);
777
-
778
- idx_t max_codes = params ? params->max_codes : this->max_codes;
823
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
824
+ idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
825
+ cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
826
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
827
+
828
+ idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
829
+ // Range-search early-stop budget. 0 disables the empty-bucket stop.
830
+ const size_t max_empty_result_buckets =
831
+ params ? params->max_empty_result_buckets : 0;
779
832
  IDSelector* sel = params ? params->sel : nullptr;
780
833
 
781
834
  FAISS_THROW_IF_NOT_MSG(
782
- !invlists->use_iterator || (max_codes == 0 && store_pairs == false),
835
+ !invlists->use_iterator ||
836
+ (cur_max_codes == 0 && store_pairs == false),
783
837
  "iterable inverted lists don't support max_codes and store_pairs");
784
838
 
839
+ FAISS_THROW_IF_NOT_MSG(
840
+ max_empty_result_buckets == 0 || parallel_mode == 0,
841
+ "max_empty_result_buckets supported only for parallel_mode = 0");
842
+
785
843
  size_t nlistv = 0, ndis = 0;
786
844
 
787
- bool interrupt = false;
788
- std::mutex exception_mutex;
789
- std::string exception_string;
845
+ std::exception_ptr ex;
790
846
 
791
847
  std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
792
848
 
@@ -795,122 +851,142 @@ void IndexIVF::range_search_preassigned(
795
851
  [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
796
852
  (pmode == 3 ? false
797
853
  : pmode == 0 ? nx > 1
798
- : pmode == 1 ? nprobe > 1
799
- : nprobe * nx > 1);
854
+ : pmode == 1 ? cur_nprobe > 1
855
+ : cur_nprobe * nx > 1);
800
856
 
801
857
  void* inverted_list_context =
802
858
  params ? params->inverted_list_context : nullptr;
803
859
 
804
860
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
805
861
  {
806
- RangeSearchPartialResult pres(result);
807
- std::unique_ptr<InvertedListScanner> scanner(
808
- get_InvertedListScanner(store_pairs, sel, params));
809
- FAISS_THROW_IF_NOT(scanner.get());
810
- all_pres[omp_get_thread_num()] = &pres;
811
-
812
- // prepare the list scanning function
813
-
814
- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
815
- idx_t key = keys[i * nprobe + ik]; /* select the list */
816
- if (key < 0) {
817
- return;
818
- }
819
- FAISS_THROW_IF_NOT_FMT(
820
- key < (idx_t)nlist,
821
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
822
- key,
823
- ik,
824
- nlist);
825
-
826
- if (invlists->is_empty(key, inverted_list_context)) {
827
- return;
828
- }
862
+ try {
863
+ RangeSearchPartialResult pres(result);
864
+ std::unique_ptr<InvertedListScanner> scanner(
865
+ get_InvertedListScanner(store_pairs, sel, params));
866
+ FAISS_THROW_IF_NOT(scanner.get());
867
+ all_pres[omp_get_thread_num()] = &pres;
868
+
869
+ // prepare the list scanning function
870
+
871
+ auto scan_list_func = [&](size_t i,
872
+ size_t ik,
873
+ RangeQueryResult& qres) {
874
+ idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
875
+ if (key < 0) {
876
+ return;
877
+ }
829
878
 
830
- try {
831
- size_t list_size = 0;
832
- scanner->set_list(key, coarse_dis[i * nprobe + ik]);
879
+ FAISS_THROW_IF_NOT_FMT(
880
+ key < (idx_t)nlist,
881
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
882
+ key,
883
+ ik,
884
+ nlist);
885
+
886
+ if (invlists->is_empty(key, inverted_list_context)) {
887
+ return;
888
+ }
889
+
890
+ scanner->set_list(key, coarse_dis[i * cur_nprobe + ik]);
891
+ const size_t scan_cnt0 = qres.stats.scan_cnt;
833
892
  if (invlists->use_iterator) {
893
+ size_t list_size = 0;
834
894
  std::unique_ptr<InvertedListsIterator> it(
835
895
  invlists->get_iterator(key, inverted_list_context));
836
896
 
837
897
  scanner->iterate_codes_range(
838
898
  it.get(), radius, qres, list_size);
899
+ qres.stats.scan_cnt += list_size;
839
900
  } else {
840
901
  InvertedLists::ScopedCodes scodes(invlists, key);
841
902
  InvertedLists::ScopedIds ids(invlists, key);
842
- list_size = invlists->list_size(key);
903
+ size_t list_size = invlists->list_size(key);
843
904
 
844
905
  scanner->scan_codes_range(
845
906
  list_size, scodes.get(), ids.get(), radius, qres);
846
907
  }
847
908
  nlistv++;
848
- ndis += list_size;
849
- } catch (const std::exception& e) {
850
- std::lock_guard<std::mutex> lock(exception_mutex);
851
- exception_string =
852
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
853
- interrupt = true;
854
- }
855
- };
909
+ ndis += qres.stats.scan_cnt - scan_cnt0;
910
+ };
856
911
 
857
- if (parallel_mode == 0) {
912
+ if (parallel_mode == 0) {
858
913
  #pragma omp for
859
- for (idx_t i = 0; i < nx; i++) {
860
- scanner->set_query(x + i * d);
861
-
862
- RangeQueryResult& qres = pres.new_result(i);
863
-
864
- for (size_t ik = 0; ik < nprobe; ik++) {
865
- scan_list_func(i, ik, qres);
914
+ for (idx_t i = 0; i < nx; i++) {
915
+ try {
916
+ scanner->set_query(x + i * d);
917
+ RangeQueryResult& qres = pres.new_result(i);
918
+
919
+ // Stop after enough consecutive probes add no range
920
+ // results. A hit resets the counter.
921
+ size_t prev_nres = qres.nres;
922
+ size_t ndup = 0;
923
+ for (idx_t ik = 0; ik < cur_nprobe; ik++) {
924
+ scan_list_func(i, ik, qres);
925
+ if (max_empty_result_buckets > 0) {
926
+ // Early-stop check: stop range search after
927
+ // enough consecutive empty probes.
928
+ ndup = (qres.nres == prev_nres) ? ndup + 1 : 0;
929
+ if (ndup >= max_empty_result_buckets) {
930
+ break;
931
+ }
932
+ prev_nres = qres.nres;
933
+ }
934
+ }
935
+ } catch (...) {
936
+ omp_capture_exception(ex);
937
+ }
866
938
  }
867
- }
868
939
 
869
- } else if (parallel_mode == 1) {
870
- for (size_t i = 0; i < nx; i++) {
871
- scanner->set_query(x + i * d);
940
+ } else if (parallel_mode == 1) {
941
+ for (idx_t i = 0; i < nx; i++) {
942
+ scanner->set_query(x + i * d);
872
943
 
873
- RangeQueryResult& qres = pres.new_result(i);
944
+ RangeQueryResult& qres = pres.new_result(i);
874
945
 
875
946
  #pragma omp for schedule(dynamic)
876
- for (int64_t ik = 0; ik < nprobe; ik++) {
877
- scan_list_func(i, ik, qres);
947
+ for (int64_t ik = 0; ik < cur_nprobe; ik++) {
948
+ try {
949
+ scan_list_func(i, ik, qres);
950
+ } catch (...) {
951
+ omp_capture_exception(ex);
952
+ }
953
+ }
878
954
  }
879
- }
880
- } else if (parallel_mode == 2) {
881
- RangeQueryResult* qres = nullptr;
955
+ } else if (parallel_mode == 2) {
956
+ RangeQueryResult* qres = nullptr;
882
957
 
883
958
  #pragma omp for schedule(dynamic)
884
- for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
885
- idx_t i = iik / (idx_t)nprobe;
886
- idx_t ik = iik % (idx_t)nprobe;
887
- if (qres == nullptr || qres->qno != i) {
888
- qres = &pres.new_result(i);
889
- scanner->set_query(x + i * d);
959
+ for (idx_t iik = 0; iik < nx * (idx_t)cur_nprobe; iik++) {
960
+ try {
961
+ idx_t i = iik / (idx_t)cur_nprobe;
962
+ idx_t ik = iik % (idx_t)cur_nprobe;
963
+ if (qres == nullptr || qres->qno != i) {
964
+ qres = &pres.new_result(i);
965
+ scanner->set_query(x + i * d);
966
+ }
967
+ scan_list_func(i, ik, *qres);
968
+ } catch (...) {
969
+ omp_capture_exception(ex);
970
+ }
890
971
  }
891
- scan_list_func(i, ik, *qres);
972
+ } else {
973
+ FAISS_THROW_FMT(
974
+ "parallel_mode %d not supported\n", parallel_mode);
892
975
  }
893
- } else {
894
- FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
895
- }
896
- if (parallel_mode == 0) {
897
- pres.finalize();
898
- } else {
976
+ if (parallel_mode == 0) {
977
+ pres.finalize();
978
+ } else {
899
979
  #pragma omp barrier
900
980
  #pragma omp single
901
- RangeSearchPartialResult::merge(all_pres, false);
981
+ RangeSearchPartialResult::merge(all_pres, false);
902
982
  #pragma omp barrier
983
+ }
984
+ } catch (...) {
985
+ omp_capture_exception(ex);
903
986
  }
904
987
  }
905
988
 
906
- if (interrupt) {
907
- if (!exception_string.empty()) {
908
- FAISS_THROW_FMT(
909
- "search interrupted with: %s", exception_string.c_str());
910
- } else {
911
- FAISS_THROW_MSG("computation interrupted");
912
- }
913
- }
989
+ omp_rethrow_if_exception(ex);
914
990
 
915
991
  if (stats == nullptr) {
916
992
  stats = &indexIVF_stats;
@@ -920,6 +996,57 @@ void IndexIVF::range_search_preassigned(
920
996
  stats->ndis += ndis;
921
997
  }
922
998
 
999
+ void IndexIVF::search1(
1000
+ const float* x,
1001
+ ResultHandler& handler,
1002
+ SearchParameters* params_in) const {
1003
+ const IVFSearchParameters* params = nullptr;
1004
+ const SearchParameters* quantizer_params = nullptr;
1005
+ if (params_in) {
1006
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
1007
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1008
+ quantizer_params = params->quantizer_params;
1009
+ }
1010
+ const size_t cur_nprobe =
1011
+ std::min(nlist, params ? params->nprobe : this->nprobe);
1012
+ size_t nx = 1;
1013
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
1014
+ std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
1015
+
1016
+ double t0 = getmillisecs();
1017
+ quantizer->search(
1018
+ nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
1019
+ indexIVF_stats.quantization_time += getmillisecs() - t0;
1020
+
1021
+ t0 = getmillisecs();
1022
+ invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
1023
+
1024
+ std::unique_ptr<InvertedListScanner> scanner(
1025
+ get_InvertedListScanner(false, nullptr, params));
1026
+ scanner->set_query(x);
1027
+
1028
+ for (size_t i = 0; i < cur_nprobe; i++) {
1029
+ idx_t key = keys[i];
1030
+ FAISS_THROW_IF_NOT_FMT(
1031
+ key < (idx_t)nlist,
1032
+ "Invalid key=%" PRId64 " nlist=%zd\n",
1033
+ key,
1034
+ nlist);
1035
+ if (key < 0 || invlists->is_empty(key)) {
1036
+ continue;
1037
+ }
1038
+
1039
+ scanner->set_list(key, coarse_dis[i]);
1040
+ InvertedLists::ScopedCodes scodes(invlists, key);
1041
+ InvertedLists::ScopedIds ids(invlists, key);
1042
+ size_t list_size = invlists->list_size(key);
1043
+
1044
+ scanner->scan_codes(list_size, scodes.get(), ids.get(), handler);
1045
+ }
1046
+
1047
+ indexIVF_stats.search_time += getmillisecs() - t0;
1048
+ }
1049
+
923
1050
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
924
1051
  bool /*store_pairs*/,
925
1052
  const IDSelector* /* sel */,
@@ -935,11 +1062,11 @@ void IndexIVF::reconstruct(idx_t key, float* recons) const {
935
1062
  void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
936
1063
  FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
937
1064
 
938
- for (idx_t list_no = 0; list_no < nlist; list_no++) {
1065
+ for (size_t list_no = 0; list_no < nlist; list_no++) {
939
1066
  size_t list_size = invlists->list_size(list_no);
940
1067
  ScopedIds idlist(invlists, list_no);
941
1068
 
942
- for (idx_t offset = 0; offset < list_size; offset++) {
1069
+ for (size_t offset = 0; offset < list_size; offset++) {
943
1070
  idx_t id = idlist[offset];
944
1071
  if (!(id >= i0 && id < i0 + ni)) {
945
1072
  continue;
@@ -1000,16 +1127,16 @@ void IndexIVF::search_and_reconstruct(
1000
1127
  params = dynamic_cast<const IVFSearchParameters*>(params_in);
1001
1128
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1002
1129
  }
1003
- const size_t nprobe =
1130
+ const size_t cur_nprobe =
1004
1131
  std::min(nlist, params ? params->nprobe : this->nprobe);
1005
- FAISS_THROW_IF_NOT(nprobe > 0);
1132
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
1006
1133
 
1007
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
1008
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1134
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
1135
+ std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
1009
1136
 
1010
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
1137
+ quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
1011
1138
 
1012
- invlists->prefetch_lists(idx.get(), n * nprobe);
1139
+ invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
1013
1140
 
1014
1141
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
1015
1142
  // and offset into `codes` for reconstruction
@@ -1031,8 +1158,8 @@ void IndexIVF::search_and_reconstruct(
1031
1158
  // Fill with NaNs
1032
1159
  memset(reconstructed, -1, sizeof(*reconstructed) * d);
1033
1160
  } else {
1034
- int list_no = lo_listno(key);
1035
- int offset = lo_offset(key);
1161
+ size_t list_no = lo_listno(key);
1162
+ size_t offset = lo_offset(key);
1036
1163
 
1037
1164
  // Update label to the actual id
1038
1165
  labels[ij] = invlists->get_single_id(list_no, offset);
@@ -1056,16 +1183,16 @@ void IndexIVF::search_and_return_codes(
1056
1183
  params = dynamic_cast<const IVFSearchParameters*>(params_in);
1057
1184
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1058
1185
  }
1059
- const size_t nprobe =
1186
+ const size_t cur_nprobe =
1060
1187
  std::min(nlist, params ? params->nprobe : this->nprobe);
1061
- FAISS_THROW_IF_NOT(nprobe > 0);
1188
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
1062
1189
 
1063
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
1064
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1190
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
1191
+ std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
1065
1192
 
1066
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
1193
+ quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
1067
1194
 
1068
- invlists->prefetch_lists(idx.get(), n * nprobe);
1195
+ invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
1069
1196
 
1070
1197
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
1071
1198
  // and offset into `codes` for reconstruction
@@ -1094,8 +1221,8 @@ void IndexIVF::search_and_return_codes(
1094
1221
  // Fill with 0xff
1095
1222
  memset(code1, -1, code_size_1);
1096
1223
  } else {
1097
- int list_no = lo_listno(key);
1098
- int offset = lo_offset(key);
1224
+ size_t list_no = lo_listno(key);
1225
+ size_t offset = lo_offset(key);
1099
1226
  const uint8_t* cc = invlists->get_single_code(list_no, offset);
1100
1227
 
1101
1228
  labels[ij] = invlists->get_single_id(list_no, offset);
@@ -1134,7 +1261,8 @@ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
1134
1261
  IDSelectorArray sel(n, new_ids);
1135
1262
  size_t nremove = remove_ids(sel);
1136
1263
  FAISS_THROW_IF_NOT_MSG(
1137
- nremove == n, "did not find all entries to remove");
1264
+ nremove == static_cast<size_t>(n),
1265
+ "did not find all entries to remove");
1138
1266
  add_with_ids(n, x, new_ids);
1139
1267
  return;
1140
1268
  }
@@ -1196,7 +1324,7 @@ idx_t IndexIVF::train_encoder_num_vectors() const {
1196
1324
  void IndexIVF::train_encoder(
1197
1325
  idx_t /*n*/,
1198
1326
  const float* /*x*/,
1199
- const idx_t* assign) {
1327
+ const idx_t* /*assign*/) {
1200
1328
  // does nothing by default
1201
1329
  if (verbose) {
1202
1330
  printf("IndexIVF: no residual training\n");
@@ -1298,6 +1426,20 @@ IndexIVFStats indexIVF_stats;
1298
1426
  * InvertedListScanner
1299
1427
  *************************************************************************/
1300
1428
 
1429
+ // this gets expanded in expanded_scanners
1430
+
1431
+ size_t InvertedListScanner::scan_codes(
1432
+ size_t list_size,
1433
+ const uint8_t* codes,
1434
+ const idx_t* ids,
1435
+ ResultHandler& handler) const {
1436
+ return run_scan_codes(*this, list_size, codes, ids, handler);
1437
+ }
1438
+
1439
+ void InvertedListScanner::set_list(idx_t list_no_in, float /* coarse_dis */) {
1440
+ this->list_no = list_no_in;
1441
+ }
1442
+
1301
1443
  size_t InvertedListScanner::scan_codes(
1302
1444
  size_t list_size,
1303
1445
  const uint8_t* codes,
@@ -1305,46 +1447,15 @@ size_t InvertedListScanner::scan_codes(
1305
1447
  float* simi,
1306
1448
  idx_t* idxi,
1307
1449
  size_t k) const {
1308
- size_t nup = 0;
1309
-
1310
1450
  if (!keep_max) {
1311
- for (size_t j = 0; j < list_size; j++) {
1312
- if (sel != nullptr) {
1313
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1314
- if (!sel->is_member(id)) {
1315
- codes += code_size;
1316
- continue;
1317
- }
1318
- }
1319
-
1320
- float dis = distance_to_code(codes);
1321
- if (dis < simi[0]) {
1322
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1323
- maxheap_replace_top(k, simi, idxi, dis, id);
1324
- nup++;
1325
- }
1326
- codes += code_size;
1327
- }
1451
+ using C = CMax<float, idx_t>;
1452
+ HeapResultHandler<C, false> handler(k, simi, idxi);
1453
+ return scan_codes(list_size, codes, ids, handler);
1328
1454
  } else {
1329
- for (size_t j = 0; j < list_size; j++) {
1330
- if (sel != nullptr) {
1331
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1332
- if (!sel->is_member(id)) {
1333
- codes += code_size;
1334
- continue;
1335
- }
1336
- }
1337
-
1338
- float dis = distance_to_code(codes);
1339
- if (dis > simi[0]) {
1340
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1341
- minheap_replace_top(k, simi, idxi, dis, id);
1342
- nup++;
1343
- }
1344
- codes += code_size;
1345
- }
1455
+ using C = CMin<float, idx_t>;
1456
+ HeapResultHandler<C, false> handler(k, simi, idxi);
1457
+ return scan_codes(list_size, codes, ids, handler);
1346
1458
  }
1347
- return nup;
1348
1459
  }
1349
1460
 
1350
1461
  size_t InvertedListScanner::iterate_codes(
@@ -1356,11 +1467,19 @@ size_t InvertedListScanner::iterate_codes(
1356
1467
  size_t nup = 0;
1357
1468
  list_size = 0;
1358
1469
 
1470
+ const bool has_cb = it->has_search_callbacks_;
1471
+
1359
1472
  if (!keep_max) {
1360
1473
  for (; it->is_available(); it->next()) {
1361
1474
  auto id_and_codes = it->get_id_and_codes();
1362
1475
  float dis = distance_to_code(id_and_codes.second);
1476
+ if (has_cb) {
1477
+ it->on_distance_computed(id_and_codes.first, dis);
1478
+ }
1363
1479
  if (dis < simi[0]) {
1480
+ if (has_cb) {
1481
+ it->on_heap_changed(id_and_codes.first, idxi[0]);
1482
+ }
1364
1483
  maxheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
1365
1484
  nup++;
1366
1485
  }
@@ -1370,7 +1489,13 @@ size_t InvertedListScanner::iterate_codes(
1370
1489
  for (; it->is_available(); it->next()) {
1371
1490
  auto id_and_codes = it->get_id_and_codes();
1372
1491
  float dis = distance_to_code(id_and_codes.second);
1492
+ if (has_cb) {
1493
+ it->on_distance_computed(id_and_codes.first, dis);
1494
+ }
1373
1495
  if (dis > simi[0]) {
1496
+ if (has_cb) {
1497
+ it->on_heap_changed(id_and_codes.first, idxi[0]);
1498
+ }
1374
1499
  minheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
1375
1500
  nup++;
1376
1501
  }
@@ -1386,16 +1511,18 @@ void InvertedListScanner::scan_codes_range(
1386
1511
  const idx_t* ids,
1387
1512
  float radius,
1388
1513
  RangeQueryResult& res) const {
1389
- for (size_t j = 0; j < list_size; j++) {
1390
- float dis = distance_to_code(codes);
1391
- bool keep = !keep_max
1392
- ? dis < radius
1393
- : dis > radius; // TODO templatize to remove this test
1394
- if (keep) {
1395
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1396
- res.add(dis, id);
1397
- }
1398
- codes += code_size;
1514
+ if (!keep_max) {
1515
+ using C = CMax<float, idx_t>;
1516
+ RangeResultHandler<C, false> handler(&res, radius);
1517
+ scan_codes(list_size, codes, ids, handler);
1518
+ res.stats.scan_cnt += handler.stats.scan_cnt;
1519
+ res.stats.nheap_updates += handler.stats.nheap_updates;
1520
+ } else {
1521
+ using C = CMin<float, idx_t>;
1522
+ RangeResultHandler<C, false> handler(&res, radius);
1523
+ scan_codes(list_size, codes, ids, handler);
1524
+ res.stats.scan_cnt += handler.stats.scan_cnt;
1525
+ res.stats.nheap_updates += handler.stats.nheap_updates;
1399
1526
  }
1400
1527
  }
1401
1528
 
@@ -1404,6 +1531,7 @@ void InvertedListScanner::iterate_codes_range(
1404
1531
  float radius,
1405
1532
  RangeQueryResult& res,
1406
1533
  size_t& list_size) const {
1534
+ size_t nup = 0;
1407
1535
  list_size = 0;
1408
1536
  for (; it->is_available(); it->next()) {
1409
1537
  auto id_and_codes = it->get_id_and_codes();
@@ -1413,9 +1541,11 @@ void InvertedListScanner::iterate_codes_range(
1413
1541
  : dis > radius; // TODO templatize to remove this test
1414
1542
  if (keep) {
1415
1543
  res.add(dis, id_and_codes.first);
1544
+ nup++;
1416
1545
  }
1417
1546
  list_size++;
1418
1547
  }
1548
+ res.stats.nheap_updates += nup;
1419
1549
  }
1420
1550
 
1421
1551
  } // namespace faiss