faiss 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (361)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +2 -1
  4. data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
  5. data/ext/faiss/index_binary.cpp +1 -1
  6. data/ext/faiss/kmeans.cpp +1 -1
  7. data/ext/faiss/pca_matrix.cpp +1 -1
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
  10. data/lib/faiss/version.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +93 -80
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -240
  13. data/vendor/faiss/faiss/Clustering.h +6 -0
  14. data/vendor/faiss/faiss/IVFlib.cpp +41 -21
  15. data/vendor/faiss/faiss/Index.cpp +6 -5
  16. data/vendor/faiss/faiss/Index.h +5 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  21. data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
  22. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
  32. data/vendor/faiss/faiss/IndexFastScan.h +25 -23
  33. data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
  34. data/vendor/faiss/faiss/IndexFlat.h +21 -18
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
  36. data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
  37. data/vendor/faiss/faiss/IndexHNSW.h +16 -2
  38. data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
  39. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  40. data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
  41. data/vendor/faiss/faiss/IndexIVF.h +33 -12
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
  45. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
  46. data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
  47. data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
  48. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  49. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
  50. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
  53. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  55. data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
  56. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
  57. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
  58. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
  59. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  62. data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
  63. data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  66. data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
  67. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
  68. data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
  69. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  73. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  74. data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
  75. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
  76. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
  77. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
  78. data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
  79. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  80. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  84. data/vendor/faiss/faiss/IndexShards.cpp +10 -9
  85. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  86. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  88. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  89. data/vendor/faiss/faiss/MetricType.h +14 -7
  90. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  91. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  92. data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
  93. data/vendor/faiss/faiss/VectorTransform.h +16 -16
  94. data/vendor/faiss/faiss/build.cpp +23 -0
  95. data/vendor/faiss/faiss/build.h +15 -0
  96. data/vendor/faiss/faiss/clone_index.cpp +48 -47
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  101. data/vendor/faiss/faiss/factory_tools.cpp +5 -0
  102. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  109. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  110. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  111. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  112. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  113. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  114. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  115. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  116. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  117. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  118. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  119. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  120. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  121. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  122. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  123. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  125. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
  127. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  128. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  129. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
  130. data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
  131. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
  132. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  133. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
  134. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  135. data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
  136. data/vendor/faiss/faiss/impl/HNSW.h +13 -34
  137. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  138. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  139. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
  141. data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
  142. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  143. data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
  144. data/vendor/faiss/faiss/impl/NSG.h +4 -4
  145. data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
  146. data/vendor/faiss/faiss/impl/Panorama.h +258 -87
  147. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  148. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  149. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
  150. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  151. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  152. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  153. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
  154. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  155. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
  156. data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
  157. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
  158. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
  159. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  160. data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
  161. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
  162. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
  163. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  164. data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
  165. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  166. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  167. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  168. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  169. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  170. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  171. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  172. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  173. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  174. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  175. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  176. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  177. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  178. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  179. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  180. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  181. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  182. data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
  183. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  184. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  185. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  186. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  187. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  188. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  189. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
  190. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  191. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  192. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  193. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  194. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  195. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  196. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
  197. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  198. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  199. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
  200. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  201. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  202. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  203. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  204. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  205. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  206. data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
  207. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
  208. data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
  209. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  210. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  211. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  212. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
  213. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  214. data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
  215. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  216. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  217. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  218. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  219. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  220. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  221. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  222. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  223. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
  224. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
  225. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
  226. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  227. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
  228. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
  229. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  230. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
  231. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  232. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
  233. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
  234. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
  235. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
  236. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
  237. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
  238. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
  239. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
  240. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
  241. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  242. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
  243. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
  244. data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
  245. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  246. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
  247. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  248. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  249. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  250. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
  251. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  252. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  253. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  254. data/vendor/faiss/faiss/index_factory.cpp +86 -18
  255. data/vendor/faiss/faiss/index_io.h +24 -0
  256. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
  257. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  258. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  259. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  260. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  261. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  262. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
  263. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  264. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
  265. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  266. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  267. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  268. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  269. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  270. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  271. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  272. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  273. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
  274. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  275. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
  276. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
  277. data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
  278. data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
  279. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  280. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  281. data/vendor/faiss/faiss/utils/distances.cpp +390 -560
  282. data/vendor/faiss/faiss/utils/distances.h +20 -1
  283. data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
  284. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  285. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  286. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  287. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  288. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  289. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  290. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
  291. data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
  292. data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
  293. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  294. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  295. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  296. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  297. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  298. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  299. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  300. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  301. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  302. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  303. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  304. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  305. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  306. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  307. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  308. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  309. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  310. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  311. data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
  312. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  313. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  314. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  315. data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
  316. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  317. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  318. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
  319. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
  320. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
  321. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
  322. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
  323. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  324. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  325. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
  326. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  327. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  328. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  329. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  330. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  331. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  332. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  333. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  339. data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
  340. data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
  341. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  342. data/vendor/faiss/faiss/utils/utils.cpp +5 -5
  343. data/vendor/faiss/faiss/utils/utils.h +3 -3
  344. metadata +119 -34
  345. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  346. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  347. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
  348. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
  349. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  350. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  351. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  352. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  353. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
  354. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  355. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  356. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
  357. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  358. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  359. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  360. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
  361. data/ext/faiss/{utils_rb.h → utils.h} +0 -0
@@ -8,9 +8,9 @@
8
8
  #include <faiss/IndexIVF.h>
9
9
 
10
10
  #include <omp.h>
11
+ #include <atomic>
11
12
  #include <cstdint>
12
13
  #include <memory>
13
- #include <mutex>
14
14
 
15
15
  #include <algorithm>
16
16
  #include <cinttypes>
@@ -37,8 +37,8 @@ using ScopedCodes = InvertedLists::ScopedCodes;
37
37
  * Level1Quantizer implementation
38
38
  ******************************************/
39
39
 
40
- Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
41
- : quantizer(quantizer), nlist(nlist) {
40
+ Level1Quantizer::Level1Quantizer(Index* quantizer_in, size_t nlist_in)
41
+ : quantizer(quantizer_in), nlist(nlist_in) {
42
42
  // here we set a low # iterations because this is typically used
43
43
  // for large clusterings (nb this is not used for the MultiIndex,
44
44
  // for which quantizer_trains_alone = true)
@@ -58,8 +58,10 @@ void Level1Quantizer::train_q1(
58
58
  const float* x,
59
59
  bool verbose,
60
60
  MetricType metric_type) {
61
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
61
62
  size_t d = quantizer->d;
62
- if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
63
+ if (quantizer->is_trained &&
64
+ (static_cast<size_t>(quantizer->ntotal) == nlist)) {
63
65
  if (verbose) {
64
66
  printf("IVF quantizer does not need training.\n");
65
67
  }
@@ -70,14 +72,14 @@ void Level1Quantizer::train_q1(
70
72
  quantizer->verbose = verbose;
71
73
  quantizer->train(n, x);
72
74
  FAISS_THROW_IF_NOT_MSG(
73
- quantizer->ntotal == nlist,
75
+ static_cast<size_t>(quantizer->ntotal) == nlist,
74
76
  "nlist not consistent with quantizer size");
75
77
  } else if (quantizer_trains_alone == 0) {
76
78
  if (verbose) {
77
79
  printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
78
80
  }
79
81
 
80
- Clustering clus(d, nlist, cp);
82
+ Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
81
83
  quantizer->reset();
82
84
  if (clustering_index) {
83
85
  clus.train(n, x, *clustering_index);
@@ -99,7 +101,7 @@ void Level1Quantizer::train_q1(
99
101
  metric_type == METRIC_L2 ||
100
102
  (metric_type == METRIC_INNER_PRODUCT && cp.spherical));
101
103
 
102
- Clustering clus(d, nlist, cp);
104
+ Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
103
105
  if (!clustering_index) {
104
106
  IndexFlatL2 assigner(d);
105
107
  clus.train(n, x, assigner);
@@ -148,7 +150,7 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
148
150
  nbit += 8;
149
151
  nl >>= 8;
150
152
  }
151
- FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
153
+ FAISS_THROW_IF_NOT(list_no >= 0 && static_cast<size_t>(list_no) < nlist);
152
154
  return list_no;
153
155
  }
154
156
 
@@ -157,21 +159,23 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
157
159
  ******************************************/
158
160
 
159
161
  IndexIVF::IndexIVF(
160
- Index* quantizer,
161
- size_t d,
162
- size_t nlist,
163
- size_t code_size,
162
+ Index* quantizer_in,
163
+ size_t d_in,
164
+ size_t nlist_in,
165
+ size_t code_size_in,
164
166
  MetricType metric,
165
- bool own_invlists)
166
- : Index(d, metric),
167
- IndexIVFInterface(quantizer, nlist),
167
+ bool own_invlists_in)
168
+ : Index(d_in, metric),
169
+ IndexIVFInterface(quantizer_in, nlist_in),
168
170
  invlists(
169
- own_invlists ? new ArrayInvertedLists(nlist, code_size)
170
- : nullptr),
171
- own_invlists(own_invlists),
172
- code_size(code_size) {
173
- FAISS_THROW_IF_NOT(d == quantizer->d);
174
- is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
171
+ own_invlists_in
172
+ ? new ArrayInvertedLists(nlist_in, code_size_in)
173
+ : nullptr),
174
+ own_invlists(own_invlists_in),
175
+ code_size(code_size_in) {
176
+ FAISS_THROW_IF_NOT(static_cast<int>(d_in) == quantizer_in->d);
177
+ is_trained = quantizer_in->is_trained &&
178
+ (static_cast<size_t>(quantizer_in->ntotal) == nlist_in);
175
179
  // Spherical by default if the metric is inner_product
176
180
  if (metric_type == METRIC_INNER_PRODUCT) {
177
181
  cp.spherical = true;
@@ -185,6 +189,8 @@ void IndexIVF::add(idx_t n, const float* x) {
185
189
  }
186
190
 
187
191
  void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
192
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
193
+ FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
188
194
  std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
189
195
  quantizer->assign(n, x, coarse_idx.get());
190
196
  add_core(n, x, xids, coarse_idx.get());
@@ -235,7 +241,7 @@ void IndexIVF::add_core(
235
241
 
236
242
  size_t nadd = 0, nminus1 = 0;
237
243
 
238
- for (size_t i = 0; i < n; i++) {
244
+ for (idx_t i = 0; i < n; i++) {
239
245
  if (coarse_idx[i] < 0) {
240
246
  nminus1++;
241
247
  }
@@ -252,7 +258,7 @@ void IndexIVF::add_core(
252
258
  int rank = omp_get_thread_num();
253
259
 
254
260
  // each thread takes care of a subset of lists
255
- for (size_t i = 0; i < n; i++) {
261
+ for (idx_t i = 0; i < n; i++) {
256
262
  idx_t list_no = coarse_idx[i];
257
263
  if (list_no >= 0 && list_no % nt == rank) {
258
264
  idx_t id = xids ? xids[i] : ntotal + i;
@@ -305,45 +311,49 @@ void IndexIVF::search(
305
311
  idx_t* labels,
306
312
  const SearchParameters* params_in) const {
307
313
  FAISS_THROW_IF_NOT(k > 0);
314
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
315
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
316
+ FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
308
317
  const IVFSearchParameters* params = nullptr;
309
318
  if (params_in) {
310
319
  params = dynamic_cast<const IVFSearchParameters*>(params_in);
311
320
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
312
321
  }
313
- const size_t nprobe =
322
+ const size_t cur_nprobe =
314
323
  std::min(nlist, params ? params->nprobe : this->nprobe);
315
- FAISS_THROW_IF_NOT(nprobe > 0);
324
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
316
325
 
317
326
  // search function for a subset of queries
318
- auto sub_search_func = [this, k, nprobe, params](
319
- idx_t n,
320
- const float* x,
321
- float* distances,
322
- idx_t* labels,
327
+ auto sub_search_func = [this, k, cur_nprobe, params](
328
+ idx_t sub_n,
329
+ const float* sub_x,
330
+ float* sub_distances,
331
+ idx_t* sub_labels,
323
332
  IndexIVFStats* ivf_stats) {
324
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
325
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
333
+ std::unique_ptr<idx_t[]> idx(new idx_t[sub_n * cur_nprobe]);
334
+ std::unique_ptr<float[]> coarse_dis(new float[sub_n * cur_nprobe]);
326
335
 
327
336
  double t0 = getmillisecs();
328
337
  quantizer->search(
329
- n,
330
- x,
331
- nprobe,
338
+ sub_n,
339
+ sub_x,
340
+ cur_nprobe,
332
341
  coarse_dis.get(),
333
342
  idx.get(),
334
343
  params ? params->quantizer_params : nullptr);
335
344
 
336
345
  double t1 = getmillisecs();
337
- invlists->prefetch_lists(idx.get(), n * nprobe);
346
+ invlists->prefetch_lists(
347
+ idx.get(), static_cast<int>(sub_n * cur_nprobe));
338
348
 
339
349
  search_preassigned(
340
- n,
341
- x,
350
+ sub_n,
351
+ sub_x,
342
352
  k,
343
353
  idx.get(),
344
354
  coarse_dis.get(),
345
- distances,
346
- labels,
355
+ sub_distances,
356
+ sub_labels,
347
357
  false,
348
358
  params,
349
359
  ivf_stats);
@@ -355,32 +365,28 @@ void IndexIVF::search(
355
365
  if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
356
366
  int nt = std::min(omp_get_max_threads(), int(n));
357
367
  std::vector<IndexIVFStats> stats(nt);
358
- std::mutex exception_mutex;
359
- std::string exception_string;
368
+ std::exception_ptr ex;
360
369
 
361
370
  #pragma omp parallel for if (nt > 1)
362
371
  for (idx_t slice = 0; slice < nt; slice++) {
363
- IndexIVFStats local_stats;
364
- idx_t i0 = n * slice / nt;
365
- idx_t i1 = n * (slice + 1) / nt;
366
- if (i1 > i0) {
367
- try {
372
+ try {
373
+ IndexIVFStats local_stats;
374
+ idx_t i0 = n * slice / nt;
375
+ idx_t i1 = n * (slice + 1) / nt;
376
+ if (i1 > i0) {
368
377
  sub_search_func(
369
378
  i1 - i0,
370
379
  x + i0 * d,
371
380
  distances + i0 * k,
372
381
  labels + i0 * k,
373
382
  &stats[slice]);
374
- } catch (const std::exception& e) {
375
- std::lock_guard<std::mutex> lock(exception_mutex);
376
- exception_string = e.what();
377
383
  }
384
+ } catch (...) {
385
+ omp_capture_exception(ex);
378
386
  }
379
387
  }
380
388
 
381
- if (!exception_string.empty()) {
382
- FAISS_THROW_MSG(exception_string.c_str());
383
- }
389
+ omp_rethrow_if_exception(ex);
384
390
 
385
391
  // collect stats
386
392
  for (idx_t slice = 0; slice < nt; slice++) {
@@ -405,13 +411,17 @@ void IndexIVF::search_preassigned(
405
411
  const IVFSearchParameters* params,
406
412
  IndexIVFStats* ivf_stats) const {
407
413
  FAISS_THROW_IF_NOT(k > 0);
414
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
415
+ FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
408
416
 
409
- idx_t nprobe = params ? params->nprobe : this->nprobe;
410
- nprobe = std::min((idx_t)nlist, nprobe);
411
- FAISS_THROW_IF_NOT(nprobe > 0);
417
+ idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
418
+ cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
419
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
412
420
 
413
421
  const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
414
- idx_t max_codes = params ? params->max_codes : this->max_codes;
422
+ idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
423
+ const bool ensure_topk_full = params ? params->ensure_topk_full : false;
424
+
415
425
  IDSelector* sel = params ? params->sel : nullptr;
416
426
  const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
417
427
  if (selr) {
@@ -427,7 +437,8 @@ void IndexIVF::search_preassigned(
427
437
  "selector and store_pairs cannot be combined");
428
438
 
429
439
  FAISS_THROW_IF_NOT_MSG(
430
- !invlists->use_iterator || (max_codes == 0 && store_pairs == false),
440
+ !invlists->use_iterator ||
441
+ (cur_max_codes == 0 && store_pairs == false),
431
442
  "iterable inverted lists don't support max_codes and store_pairs");
432
443
 
433
444
  size_t nlistv = 0, ndis = 0, nheap = 0;
@@ -435,106 +446,119 @@ void IndexIVF::search_preassigned(
435
446
  using HeapForIP = CMin<float, idx_t>;
436
447
  using HeapForL2 = CMax<float, idx_t>;
437
448
 
438
- bool interrupt = false;
439
- std::mutex exception_mutex;
440
- std::string exception_string;
449
+ std::exception_ptr ex;
450
+ std::atomic<bool> interrupt{false};
441
451
 
442
452
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
443
453
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
444
454
 
445
455
  FAISS_THROW_IF_NOT_MSG(
446
- max_codes == 0 || pmode == 0 || pmode == 3,
456
+ cur_max_codes == 0 || pmode == 0 || pmode == 3,
447
457
  "max_codes supported only for parallel_mode = 0 or 3");
448
458
 
449
- if (max_codes == 0) {
450
- max_codes = unlimited_list_size;
459
+ FAISS_THROW_IF_NOT_MSG(
460
+ !ensure_topk_full || pmode == 0 || pmode == 3,
461
+ "ensure_topk_full supported only for parallel_mode = 0 or 3");
462
+
463
+ if (cur_max_codes == 0) {
464
+ cur_max_codes = unlimited_list_size;
451
465
  }
466
+ // Budget used by the probe loop below. ensure_topk_full makes a small
467
+ // max_codes budget large enough to give k post-filter candidates a chance.
468
+ idx_t effective_max_codes =
469
+ ensure_topk_full ? std::max(cur_max_codes, k) : cur_max_codes;
452
470
 
453
471
  [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
454
472
  (pmode == 0 ? false
455
473
  : pmode == 3 ? n > 1
456
- : pmode == 1 ? nprobe > 1
457
- : nprobe * n > 1);
474
+ : pmode == 1 ? cur_nprobe > 1
475
+ : cur_nprobe * n > 1);
458
476
 
459
477
  void* inverted_list_context =
460
478
  params ? params->inverted_list_context : nullptr;
461
479
 
462
480
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
463
481
  {
464
- std::unique_ptr<InvertedListScanner> scanner(
465
- get_InvertedListScanner(store_pairs, sel, params));
466
-
467
- /*****************************************************
468
- * Depending on parallel_mode, there are two possible ways
469
- * to organize the search. Here we define local functions
470
- * that are in common between the two
471
- ******************************************************/
472
-
473
- // initialize + reorder a result heap
482
+ // C++ exceptions that escape an OpenMP parallel region without being
483
+ // caught inside it call std::terminate — they cannot propagate across
484
+ // thread boundaries. The outer try/catch covers per-thread setup
485
+ // (scanner creation, set_query); the inner try/catch in scan_one_list
486
+ // covers per-list operations. Both set interrupt=true to stop further
487
+ // work and re-throw after the parallel region exits.
488
+ try {
489
+ std::unique_ptr<InvertedListScanner> scanner(
490
+ get_InvertedListScanner(store_pairs, sel, params));
491
+
492
+ /*****************************************************
493
+ * Depending on parallel_mode, there are two possible ways
494
+ * to organize the search. Here we define local functions
495
+ * that are in common between the two
496
+ ******************************************************/
497
+
498
+ // initialize + reorder a result heap
499
+
500
+ auto init_result = [&](float* simi, idx_t* idxi) {
501
+ if (!do_heap_init) {
502
+ return;
503
+ }
504
+ if (metric_type == METRIC_INNER_PRODUCT) {
505
+ heap_heapify<HeapForIP>(k, simi, idxi);
506
+ } else {
507
+ heap_heapify<HeapForL2>(k, simi, idxi);
508
+ }
509
+ };
510
+
511
+ auto add_local_results = [&](const float* local_dis,
512
+ const idx_t* local_idx,
513
+ float* simi,
514
+ idx_t* idxi) {
515
+ if (metric_type == METRIC_INNER_PRODUCT) {
516
+ heap_addn<HeapForIP>(
517
+ k, simi, idxi, local_dis, local_idx, k);
518
+ } else {
519
+ heap_addn<HeapForL2>(
520
+ k, simi, idxi, local_dis, local_idx, k);
521
+ }
522
+ };
474
523
 
475
- auto init_result = [&](float* simi, idx_t* idxi) {
476
- if (!do_heap_init) {
477
- return;
478
- }
479
- if (metric_type == METRIC_INNER_PRODUCT) {
480
- heap_heapify<HeapForIP>(k, simi, idxi);
481
- } else {
482
- heap_heapify<HeapForL2>(k, simi, idxi);
483
- }
484
- };
524
+ auto reorder_result = [&](float* simi, idx_t* idxi) {
525
+ if (!do_heap_init) {
526
+ return;
527
+ }
528
+ if (metric_type == METRIC_INNER_PRODUCT) {
529
+ heap_reorder<HeapForIP>(k, simi, idxi);
530
+ } else {
531
+ heap_reorder<HeapForL2>(k, simi, idxi);
532
+ }
533
+ };
485
534
 
486
- auto add_local_results = [&](const float* local_dis,
487
- const idx_t* local_idx,
535
+ // single list scan using the current scanner (with query
536
+ // set properly) and storing results in simi and idxi
537
+ auto scan_one_list = [&](idx_t key,
538
+ float coarse_dis_i,
488
539
  float* simi,
489
- idx_t* idxi) {
490
- if (metric_type == METRIC_INNER_PRODUCT) {
491
- heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
492
- } else {
493
- heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
494
- }
495
- };
496
-
497
- auto reorder_result = [&](float* simi, idx_t* idxi) {
498
- if (!do_heap_init) {
499
- return;
500
- }
501
- if (metric_type == METRIC_INNER_PRODUCT) {
502
- heap_reorder<HeapForIP>(k, simi, idxi);
503
- } else {
504
- heap_reorder<HeapForL2>(k, simi, idxi);
505
- }
506
- };
507
-
508
- // single list scan using the current scanner (with query
509
- // set properly) and storing results in simi and idxi
510
- auto scan_one_list = [&](idx_t key,
511
- float coarse_dis_i,
512
- float* simi,
513
- idx_t* idxi,
514
- idx_t list_size_max) {
515
- if (key < 0) {
516
- // not enough centroids for multiprobe
517
- return (size_t)0;
518
- }
519
- FAISS_THROW_IF_NOT_FMT(
520
- key < (idx_t)nlist,
521
- "Invalid key=%" PRId64 " nlist=%zd\n",
522
- key,
523
- nlist);
524
-
525
- // don't waste time on empty lists
526
- if (invlists->is_empty(key, inverted_list_context)) {
527
- return (size_t)0;
528
- }
529
-
530
- scanner->set_list(key, coarse_dis_i);
540
+ idx_t* idxi,
541
+ idx_t list_size_max) {
542
+ if (key < 0) {
543
+ // not enough centroids for multiprobe
544
+ return (size_t)0;
545
+ }
546
+ FAISS_THROW_IF_NOT_FMT(
547
+ key < (idx_t)nlist,
548
+ "Invalid key=%" PRId64 " nlist=%zd\n",
549
+ key,
550
+ nlist);
551
+
552
+ // don't waste time on empty lists
553
+ if (invlists->is_empty(key, inverted_list_context)) {
554
+ return (size_t)0;
555
+ }
531
556
 
532
- nlistv++;
557
+ scanner->set_list(key, coarse_dis_i);
533
558
 
534
- try {
559
+ nlistv++;
535
560
  if (invlists->use_iterator) {
536
561
  size_t list_size = 0;
537
-
538
562
  std::unique_ptr<InvertedListsIterator> it(
539
563
  invlists->get_iterator(key, inverted_list_context));
540
564
 
@@ -544,8 +568,8 @@ void IndexIVF::search_preassigned(
544
568
  return list_size;
545
569
  } else {
546
570
  size_t list_size = invlists->list_size(key);
547
- if (list_size > list_size_max) {
548
- list_size = list_size_max;
571
+ if (list_size > static_cast<size_t>(list_size_max)) {
572
+ list_size = static_cast<size_t>(list_size_max);
549
573
  }
550
574
 
551
575
  InvertedLists::ScopedCodes scodes(invlists, key);
@@ -573,144 +597,167 @@ void IndexIVF::search_preassigned(
573
597
  ids += jmin;
574
598
  }
575
599
 
576
- nheap += scanner->scan_codes(
577
- list_size, codes, ids, simi, idxi, k);
578
-
579
- return list_size;
600
+ size_t old_scan_cnt = 0;
601
+ size_t old_heap_updates = 0;
602
+ if (metric_type == METRIC_INNER_PRODUCT) {
603
+ HeapResultHandler<HeapForIP, false> handler(
604
+ k, simi, idxi);
605
+ old_scan_cnt = handler.stats.scan_cnt;
606
+ old_heap_updates = handler.stats.nheap_updates;
607
+ scanner->scan_codes(list_size, codes, ids, handler);
608
+ nheap += handler.stats.nheap_updates - old_heap_updates;
609
+ return handler.stats.scan_cnt - old_scan_cnt;
610
+ } else {
611
+ HeapResultHandler<HeapForL2, false> handler(
612
+ k, simi, idxi);
613
+ old_scan_cnt = handler.stats.scan_cnt;
614
+ old_heap_updates = handler.stats.nheap_updates;
615
+ scanner->scan_codes(list_size, codes, ids, handler);
616
+ nheap += handler.stats.nheap_updates - old_heap_updates;
617
+ return handler.stats.scan_cnt - old_scan_cnt;
618
+ }
580
619
  }
581
- } catch (const std::exception& e) {
582
- std::lock_guard<std::mutex> lock(exception_mutex);
583
- exception_string =
584
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
585
- interrupt = true;
586
- return size_t(0);
587
- }
588
- };
620
+ };
589
621
 
590
- /****************************************************
591
- * Actual loops, depending on parallel_mode
592
- ****************************************************/
622
+ /****************************************************
623
+ * Actual loops, depending on parallel_mode
624
+ ****************************************************/
593
625
 
594
- if (pmode == 0 || pmode == 3) {
626
+ if (pmode == 0 || pmode == 3) {
595
627
  #pragma omp for
596
- for (idx_t i = 0; i < n; i++) {
597
- if (interrupt) {
598
- continue;
599
- }
600
-
601
- // loop over queries
602
- scanner->set_query(x + i * d);
603
- float* simi = distances + i * k;
604
- idx_t* idxi = labels + i * k;
605
-
606
- init_result(simi, idxi);
607
-
608
- idx_t nscan = 0;
609
-
610
- // loop over probes
611
- for (size_t ik = 0; ik < nprobe; ik++) {
612
- nscan += scan_one_list(
613
- keys[i * nprobe + ik],
614
- coarse_dis[i * nprobe + ik],
615
- simi,
616
- idxi,
617
- max_codes - nscan);
618
- if (nscan >= max_codes) {
619
- break;
628
+ for (idx_t i = 0; i < n; i++) {
629
+ if (interrupt.load(std::memory_order_relaxed)) {
630
+ continue;
620
631
  }
621
- }
622
-
623
- ndis += nscan;
624
- reorder_result(simi, idxi);
632
+ try {
633
+ // loop over queries
634
+ scanner->set_query(x + i * d);
635
+ float* simi = distances + i * k;
636
+ idx_t* idxi = labels + i * k;
637
+
638
+ init_result(simi, idxi);
639
+
640
+ idx_t nscan = 0;
641
+
642
+ // loop over probes
643
+ for (idx_t ik = 0; ik < cur_nprobe; ik++) {
644
+ // For soft budgets, scan whole lists so
645
+ // IDSelector-filtered rows do not consume the
646
+ // remaining code budget.
647
+ const idx_t list_size_max = ensure_topk_full
648
+ ? unlimited_list_size
649
+ : effective_max_codes - nscan;
650
+ nscan += scan_one_list(
651
+ keys[i * cur_nprobe + ik],
652
+ coarse_dis[i * cur_nprobe + ik],
653
+ simi,
654
+ idxi,
655
+ list_size_max);
656
+
657
+ // Early-stop check: apply max_codes after each
658
+ // list. nscan is the number of distances
659
+ // actually computed.
660
+ if (nscan >= effective_max_codes) {
661
+ break;
662
+ }
663
+ }
625
664
 
626
- if (InterruptCallback::is_interrupted()) {
627
- interrupt = true;
628
- }
665
+ ndis += nscan;
666
+ reorder_result(simi, idxi);
629
667
 
630
- } // parallel for
631
- } else if (pmode == 1) {
632
- std::vector<idx_t> local_idx(k);
633
- std::vector<float> local_dis(k);
668
+ InterruptCallback::check();
669
+ } catch (...) {
670
+ omp_capture_exception(ex, [&] { interrupt = true; });
671
+ }
672
+ } // parallel for
673
+ } else if (pmode == 1) {
674
+ std::vector<idx_t> local_idx(k);
675
+ std::vector<float> local_dis(k);
634
676
 
635
- for (size_t i = 0; i < n; i++) {
636
- scanner->set_query(x + i * d);
637
- init_result(local_dis.data(), local_idx.data());
677
+ for (idx_t i = 0; i < n; i++) {
678
+ scanner->set_query(x + i * d);
679
+ init_result(local_dis.data(), local_idx.data());
638
680
 
639
681
  #pragma omp for schedule(dynamic)
640
- for (idx_t ik = 0; ik < nprobe; ik++) {
641
- ndis += scan_one_list(
642
- keys[i * nprobe + ik],
643
- coarse_dis[i * nprobe + ik],
644
- local_dis.data(),
645
- local_idx.data(),
646
- unlimited_list_size);
647
-
648
- // can't do the test on max_codes
649
- }
650
- // merge thread-local results
682
+ for (idx_t ik = 0; ik < cur_nprobe; ik++) {
683
+ try {
684
+ ndis += scan_one_list(
685
+ keys[i * cur_nprobe + ik],
686
+ coarse_dis[i * cur_nprobe + ik],
687
+ local_dis.data(),
688
+ local_idx.data(),
689
+ unlimited_list_size);
690
+
691
+ // can't do the test on max_codes
692
+ } catch (...) {
693
+ omp_capture_exception(
694
+ ex, [&] { interrupt = true; });
695
+ }
696
+ }
697
+ // merge thread-local results
651
698
 
652
- float* simi = distances + i * k;
653
- idx_t* idxi = labels + i * k;
699
+ float* simi = distances + i * k;
700
+ idx_t* idxi = labels + i * k;
654
701
  #pragma omp single
655
- init_result(simi, idxi);
702
+ init_result(simi, idxi);
656
703
 
657
704
  #pragma omp barrier
658
705
  #pragma omp critical
659
- {
660
- add_local_results(
661
- local_dis.data(), local_idx.data(), simi, idxi);
662
- }
706
+ {
707
+ add_local_results(
708
+ local_dis.data(), local_idx.data(), simi, idxi);
709
+ }
663
710
  #pragma omp barrier
664
711
  #pragma omp single
665
- reorder_result(simi, idxi);
666
- }
667
- } else if (pmode == 2) {
668
- std::vector<idx_t> local_idx(k);
669
- std::vector<float> local_dis(k);
712
+ reorder_result(simi, idxi);
713
+ }
714
+ } else if (pmode == 2) {
715
+ std::vector<idx_t> local_idx(k);
716
+ std::vector<float> local_dis(k);
670
717
 
671
718
  #pragma omp single
672
- for (int64_t i = 0; i < n; i++) {
673
- init_result(distances + i * k, labels + i * k);
674
- }
719
+ for (int64_t i = 0; i < n; i++) {
720
+ init_result(distances + i * k, labels + i * k);
721
+ }
675
722
 
676
723
  #pragma omp for schedule(dynamic)
677
- for (int64_t ij = 0; ij < n * nprobe; ij++) {
678
- size_t i = ij / nprobe;
679
-
680
- scanner->set_query(x + i * d);
681
- init_result(local_dis.data(), local_idx.data());
682
- ndis += scan_one_list(
683
- keys[ij],
684
- coarse_dis[ij],
685
- local_dis.data(),
686
- local_idx.data(),
687
- unlimited_list_size);
724
+ for (int64_t ij = 0; ij < n * cur_nprobe; ij++) {
725
+ try {
726
+ size_t i = ij / cur_nprobe;
727
+
728
+ scanner->set_query(x + i * d);
729
+ init_result(local_dis.data(), local_idx.data());
730
+ ndis += scan_one_list(
731
+ keys[ij],
732
+ coarse_dis[ij],
733
+ local_dis.data(),
734
+ local_idx.data(),
735
+ unlimited_list_size);
688
736
  #pragma omp critical
689
- {
690
- add_local_results(
691
- local_dis.data(),
692
- local_idx.data(),
693
- distances + i * k,
694
- labels + i * k);
737
+ {
738
+ add_local_results(
739
+ local_dis.data(),
740
+ local_idx.data(),
741
+ distances + i * k,
742
+ labels + i * k);
743
+ }
744
+ } catch (...) {
745
+ omp_capture_exception(ex, [&] { interrupt = true; });
746
+ }
695
747
  }
696
- }
697
748
  #pragma omp single
698
- for (int64_t i = 0; i < n; i++) {
699
- reorder_result(distances + i * k, labels + i * k);
749
+ for (int64_t i = 0; i < n; i++) {
750
+ reorder_result(distances + i * k, labels + i * k);
751
+ }
752
+ } else {
753
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
700
754
  }
701
- } else {
702
- FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
755
+ } catch (...) {
756
+ omp_capture_exception(ex, [&] { interrupt = true; });
703
757
  }
704
758
  } // parallel section
705
759
 
706
- if (interrupt) {
707
- if (!exception_string.empty()) {
708
- FAISS_THROW_FMT(
709
- "search interrupted with: %s", exception_string.c_str());
710
- } else {
711
- FAISS_THROW_MSG("computation interrupted");
712
- }
713
- }
760
+ omp_rethrow_if_exception(ex);
714
761
 
715
762
  if (ivf_stats == nullptr) {
716
763
  ivf_stats = &indexIVF_stats;
@@ -727,6 +774,8 @@ void IndexIVF::range_search(
727
774
  float radius,
728
775
  RangeSearchResult* result,
729
776
  const SearchParameters* params_in) const {
777
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
778
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
730
779
  const IVFSearchParameters* params = nullptr;
731
780
  const SearchParameters* quantizer_params = nullptr;
732
781
  if (params_in) {
@@ -734,18 +783,18 @@ void IndexIVF::range_search(
734
783
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
735
784
  quantizer_params = params->quantizer_params;
736
785
  }
737
- const size_t nprobe =
786
+ const size_t cur_nprobe =
738
787
  std::min(nlist, params ? params->nprobe : this->nprobe);
739
- std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
740
- std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
788
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
789
+ std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
741
790
 
742
791
  double t0 = getmillisecs();
743
792
  quantizer->search(
744
- nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
793
+ nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
745
794
  indexIVF_stats.quantization_time += getmillisecs() - t0;
746
795
 
747
796
  t0 = getmillisecs();
748
- invlists->prefetch_lists(keys.get(), nx * nprobe);
797
+ invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
749
798
 
750
799
  range_search_preassigned(
751
800
  nx,
@@ -771,22 +820,29 @@ void IndexIVF::range_search_preassigned(
771
820
  bool store_pairs,
772
821
  const IVFSearchParameters* params,
773
822
  IndexIVFStats* stats) const {
774
- idx_t nprobe = params ? params->nprobe : this->nprobe;
775
- nprobe = std::min((idx_t)nlist, nprobe);
776
- FAISS_THROW_IF_NOT(nprobe > 0);
777
-
778
- idx_t max_codes = params ? params->max_codes : this->max_codes;
823
+ FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
824
+ idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
825
+ cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
826
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
827
+
828
+ idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
829
+ // Range-search early-stop budget. 0 disables the empty-bucket stop.
830
+ const size_t max_empty_result_buckets =
831
+ params ? params->max_empty_result_buckets : 0;
779
832
  IDSelector* sel = params ? params->sel : nullptr;
780
833
 
781
834
  FAISS_THROW_IF_NOT_MSG(
782
- !invlists->use_iterator || (max_codes == 0 && store_pairs == false),
835
+ !invlists->use_iterator ||
836
+ (cur_max_codes == 0 && store_pairs == false),
783
837
  "iterable inverted lists don't support max_codes and store_pairs");
784
838
 
839
+ FAISS_THROW_IF_NOT_MSG(
840
+ max_empty_result_buckets == 0 || parallel_mode == 0,
841
+ "max_empty_result_buckets supported only for parallel_mode = 0");
842
+
785
843
  size_t nlistv = 0, ndis = 0;
786
844
 
787
- bool interrupt = false;
788
- std::mutex exception_mutex;
789
- std::string exception_string;
845
+ std::exception_ptr ex;
790
846
 
791
847
  std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
792
848
 
@@ -795,122 +851,142 @@ void IndexIVF::range_search_preassigned(
795
851
  [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
796
852
  (pmode == 3 ? false
797
853
  : pmode == 0 ? nx > 1
798
- : pmode == 1 ? nprobe > 1
799
- : nprobe * nx > 1);
854
+ : pmode == 1 ? cur_nprobe > 1
855
+ : cur_nprobe * nx > 1);
800
856
 
801
857
  void* inverted_list_context =
802
858
  params ? params->inverted_list_context : nullptr;
803
859
 
804
860
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
805
861
  {
806
- RangeSearchPartialResult pres(result);
807
- std::unique_ptr<InvertedListScanner> scanner(
808
- get_InvertedListScanner(store_pairs, sel, params));
809
- FAISS_THROW_IF_NOT(scanner.get());
810
- all_pres[omp_get_thread_num()] = &pres;
811
-
812
- // prepare the list scanning function
813
-
814
- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
815
- idx_t key = keys[i * nprobe + ik]; /* select the list */
816
- if (key < 0) {
817
- return;
818
- }
819
- FAISS_THROW_IF_NOT_FMT(
820
- key < (idx_t)nlist,
821
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
822
- key,
823
- ik,
824
- nlist);
825
-
826
- if (invlists->is_empty(key, inverted_list_context)) {
827
- return;
828
- }
862
+ try {
863
+ RangeSearchPartialResult pres(result);
864
+ std::unique_ptr<InvertedListScanner> scanner(
865
+ get_InvertedListScanner(store_pairs, sel, params));
866
+ FAISS_THROW_IF_NOT(scanner.get());
867
+ all_pres[omp_get_thread_num()] = &pres;
868
+
869
+ // prepare the list scanning function
870
+
871
+ auto scan_list_func = [&](size_t i,
872
+ size_t ik,
873
+ RangeQueryResult& qres) {
874
+ idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
875
+ if (key < 0) {
876
+ return;
877
+ }
829
878
 
830
- try {
831
- size_t list_size = 0;
832
- scanner->set_list(key, coarse_dis[i * nprobe + ik]);
879
+ FAISS_THROW_IF_NOT_FMT(
880
+ key < (idx_t)nlist,
881
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
882
+ key,
883
+ ik,
884
+ nlist);
885
+
886
+ if (invlists->is_empty(key, inverted_list_context)) {
887
+ return;
888
+ }
889
+
890
+ scanner->set_list(key, coarse_dis[i * cur_nprobe + ik]);
891
+ const size_t scan_cnt0 = qres.stats.scan_cnt;
833
892
  if (invlists->use_iterator) {
893
+ size_t list_size = 0;
834
894
  std::unique_ptr<InvertedListsIterator> it(
835
895
  invlists->get_iterator(key, inverted_list_context));
836
896
 
837
897
  scanner->iterate_codes_range(
838
898
  it.get(), radius, qres, list_size);
899
+ qres.stats.scan_cnt += list_size;
839
900
  } else {
840
901
  InvertedLists::ScopedCodes scodes(invlists, key);
841
902
  InvertedLists::ScopedIds ids(invlists, key);
842
- list_size = invlists->list_size(key);
903
+ size_t list_size = invlists->list_size(key);
843
904
 
844
905
  scanner->scan_codes_range(
845
906
  list_size, scodes.get(), ids.get(), radius, qres);
846
907
  }
847
908
  nlistv++;
848
- ndis += list_size;
849
- } catch (const std::exception& e) {
850
- std::lock_guard<std::mutex> lock(exception_mutex);
851
- exception_string =
852
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
853
- interrupt = true;
854
- }
855
- };
909
+ ndis += qres.stats.scan_cnt - scan_cnt0;
910
+ };
856
911
 
857
- if (parallel_mode == 0) {
912
+ if (parallel_mode == 0) {
858
913
  #pragma omp for
859
- for (idx_t i = 0; i < nx; i++) {
860
- scanner->set_query(x + i * d);
861
-
862
- RangeQueryResult& qres = pres.new_result(i);
863
-
864
- for (size_t ik = 0; ik < nprobe; ik++) {
865
- scan_list_func(i, ik, qres);
914
+ for (idx_t i = 0; i < nx; i++) {
915
+ try {
916
+ scanner->set_query(x + i * d);
917
+ RangeQueryResult& qres = pres.new_result(i);
918
+
919
+ // Stop after enough consecutive probes add no range
920
+ // results. A hit resets the counter.
921
+ size_t prev_nres = qres.nres;
922
+ size_t ndup = 0;
923
+ for (idx_t ik = 0; ik < cur_nprobe; ik++) {
924
+ scan_list_func(i, ik, qres);
925
+ if (max_empty_result_buckets > 0) {
926
+ // Early-stop check: stop range search after
927
+ // enough consecutive empty probes.
928
+ ndup = (qres.nres == prev_nres) ? ndup + 1 : 0;
929
+ if (ndup >= max_empty_result_buckets) {
930
+ break;
931
+ }
932
+ prev_nres = qres.nres;
933
+ }
934
+ }
935
+ } catch (...) {
936
+ omp_capture_exception(ex);
937
+ }
866
938
  }
867
- }
868
939
 
869
- } else if (parallel_mode == 1) {
870
- for (size_t i = 0; i < nx; i++) {
871
- scanner->set_query(x + i * d);
940
+ } else if (parallel_mode == 1) {
941
+ for (idx_t i = 0; i < nx; i++) {
942
+ scanner->set_query(x + i * d);
872
943
 
873
- RangeQueryResult& qres = pres.new_result(i);
944
+ RangeQueryResult& qres = pres.new_result(i);
874
945
 
875
946
  #pragma omp for schedule(dynamic)
876
- for (int64_t ik = 0; ik < nprobe; ik++) {
877
- scan_list_func(i, ik, qres);
947
+ for (int64_t ik = 0; ik < cur_nprobe; ik++) {
948
+ try {
949
+ scan_list_func(i, ik, qres);
950
+ } catch (...) {
951
+ omp_capture_exception(ex);
952
+ }
953
+ }
878
954
  }
879
- }
880
- } else if (parallel_mode == 2) {
881
- RangeQueryResult* qres = nullptr;
955
+ } else if (parallel_mode == 2) {
956
+ RangeQueryResult* qres = nullptr;
882
957
 
883
958
  #pragma omp for schedule(dynamic)
884
- for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
885
- idx_t i = iik / (idx_t)nprobe;
886
- idx_t ik = iik % (idx_t)nprobe;
887
- if (qres == nullptr || qres->qno != i) {
888
- qres = &pres.new_result(i);
889
- scanner->set_query(x + i * d);
959
+ for (idx_t iik = 0; iik < nx * (idx_t)cur_nprobe; iik++) {
960
+ try {
961
+ idx_t i = iik / (idx_t)cur_nprobe;
962
+ idx_t ik = iik % (idx_t)cur_nprobe;
963
+ if (qres == nullptr || qres->qno != i) {
964
+ qres = &pres.new_result(i);
965
+ scanner->set_query(x + i * d);
966
+ }
967
+ scan_list_func(i, ik, *qres);
968
+ } catch (...) {
969
+ omp_capture_exception(ex);
970
+ }
890
971
  }
891
- scan_list_func(i, ik, *qres);
972
+ } else {
973
+ FAISS_THROW_FMT(
974
+ "parallel_mode %d not supported\n", parallel_mode);
892
975
  }
893
- } else {
894
- FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
895
- }
896
- if (parallel_mode == 0) {
897
- pres.finalize();
898
- } else {
976
+ if (parallel_mode == 0) {
977
+ pres.finalize();
978
+ } else {
899
979
  #pragma omp barrier
900
980
  #pragma omp single
901
- RangeSearchPartialResult::merge(all_pres, false);
981
+ RangeSearchPartialResult::merge(all_pres, false);
902
982
  #pragma omp barrier
983
+ }
984
+ } catch (...) {
985
+ omp_capture_exception(ex);
903
986
  }
904
987
  }
905
988
 
906
- if (interrupt) {
907
- if (!exception_string.empty()) {
908
- FAISS_THROW_FMT(
909
- "search interrupted with: %s", exception_string.c_str());
910
- } else {
911
- FAISS_THROW_MSG("computation interrupted");
912
- }
913
- }
989
+ omp_rethrow_if_exception(ex);
914
990
 
915
991
  if (stats == nullptr) {
916
992
  stats = &indexIVF_stats;
@@ -931,26 +1007,31 @@ void IndexIVF::search1(
931
1007
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
932
1008
  quantizer_params = params->quantizer_params;
933
1009
  }
934
- const size_t nprobe =
1010
+ const size_t cur_nprobe =
935
1011
  std::min(nlist, params ? params->nprobe : this->nprobe);
936
1012
  size_t nx = 1;
937
- std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
938
- std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
1013
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
1014
+ std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
939
1015
 
940
1016
  double t0 = getmillisecs();
941
1017
  quantizer->search(
942
- nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
1018
+ nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
943
1019
  indexIVF_stats.quantization_time += getmillisecs() - t0;
944
1020
 
945
1021
  t0 = getmillisecs();
946
- invlists->prefetch_lists(keys.get(), nx * nprobe);
1022
+ invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
947
1023
 
948
1024
  std::unique_ptr<InvertedListScanner> scanner(
949
1025
  get_InvertedListScanner(false, nullptr, params));
950
1026
  scanner->set_query(x);
951
1027
 
952
- for (idx_t i = 0; i < nprobe; i++) {
1028
+ for (size_t i = 0; i < cur_nprobe; i++) {
953
1029
  idx_t key = keys[i];
1030
+ FAISS_THROW_IF_NOT_FMT(
1031
+ key < (idx_t)nlist,
1032
+ "Invalid key=%" PRId64 " nlist=%zd\n",
1033
+ key,
1034
+ nlist);
954
1035
  if (key < 0 || invlists->is_empty(key)) {
955
1036
  continue;
956
1037
  }
@@ -981,11 +1062,11 @@ void IndexIVF::reconstruct(idx_t key, float* recons) const {
981
1062
  void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
982
1063
  FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
983
1064
 
984
- for (idx_t list_no = 0; list_no < nlist; list_no++) {
1065
+ for (size_t list_no = 0; list_no < nlist; list_no++) {
985
1066
  size_t list_size = invlists->list_size(list_no);
986
1067
  ScopedIds idlist(invlists, list_no);
987
1068
 
988
- for (idx_t offset = 0; offset < list_size; offset++) {
1069
+ for (size_t offset = 0; offset < list_size; offset++) {
989
1070
  idx_t id = idlist[offset];
990
1071
  if (!(id >= i0 && id < i0 + ni)) {
991
1072
  continue;
@@ -1046,16 +1127,16 @@ void IndexIVF::search_and_reconstruct(
1046
1127
  params = dynamic_cast<const IVFSearchParameters*>(params_in);
1047
1128
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1048
1129
  }
1049
- const size_t nprobe =
1130
+ const size_t cur_nprobe =
1050
1131
  std::min(nlist, params ? params->nprobe : this->nprobe);
1051
- FAISS_THROW_IF_NOT(nprobe > 0);
1132
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
1052
1133
 
1053
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
1054
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1134
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
1135
+ std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
1055
1136
 
1056
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
1137
+ quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
1057
1138
 
1058
- invlists->prefetch_lists(idx.get(), n * nprobe);
1139
+ invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
1059
1140
 
1060
1141
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
1061
1142
  // and offset into `codes` for reconstruction
@@ -1077,8 +1158,8 @@ void IndexIVF::search_and_reconstruct(
1077
1158
  // Fill with NaNs
1078
1159
  memset(reconstructed, -1, sizeof(*reconstructed) * d);
1079
1160
  } else {
1080
- int list_no = lo_listno(key);
1081
- int offset = lo_offset(key);
1161
+ size_t list_no = lo_listno(key);
1162
+ size_t offset = lo_offset(key);
1082
1163
 
1083
1164
  // Update label to the actual id
1084
1165
  labels[ij] = invlists->get_single_id(list_no, offset);
@@ -1102,16 +1183,16 @@ void IndexIVF::search_and_return_codes(
1102
1183
  params = dynamic_cast<const IVFSearchParameters*>(params_in);
1103
1184
  FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1104
1185
  }
1105
- const size_t nprobe =
1186
+ const size_t cur_nprobe =
1106
1187
  std::min(nlist, params ? params->nprobe : this->nprobe);
1107
- FAISS_THROW_IF_NOT(nprobe > 0);
1188
+ FAISS_THROW_IF_NOT(cur_nprobe > 0);
1108
1189
 
1109
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
1110
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1190
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
1191
+ std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
1111
1192
 
1112
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
1193
+ quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
1113
1194
 
1114
- invlists->prefetch_lists(idx.get(), n * nprobe);
1195
+ invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
1115
1196
 
1116
1197
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
1117
1198
  // and offset into `codes` for reconstruction
@@ -1140,8 +1221,8 @@ void IndexIVF::search_and_return_codes(
1140
1221
  // Fill with 0xff
1141
1222
  memset(code1, -1, code_size_1);
1142
1223
  } else {
1143
- int list_no = lo_listno(key);
1144
- int offset = lo_offset(key);
1224
+ size_t list_no = lo_listno(key);
1225
+ size_t offset = lo_offset(key);
1145
1226
  const uint8_t* cc = invlists->get_single_code(list_no, offset);
1146
1227
 
1147
1228
  labels[ij] = invlists->get_single_id(list_no, offset);
@@ -1180,7 +1261,8 @@ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
1180
1261
  IDSelectorArray sel(n, new_ids);
1181
1262
  size_t nremove = remove_ids(sel);
1182
1263
  FAISS_THROW_IF_NOT_MSG(
1183
- nremove == n, "did not find all entries to remove");
1264
+ nremove == static_cast<size_t>(n),
1265
+ "did not find all entries to remove");
1184
1266
  add_with_ids(n, x, new_ids);
1185
1267
  return;
1186
1268
  }
@@ -1242,7 +1324,7 @@ idx_t IndexIVF::train_encoder_num_vectors() const {
1242
1324
  void IndexIVF::train_encoder(
1243
1325
  idx_t /*n*/,
1244
1326
  const float* /*x*/,
1245
- const idx_t* assign) {
1327
+ const idx_t* /*assign*/) {
1246
1328
  // does nothing by default
1247
1329
  if (verbose) {
1248
1330
  printf("IndexIVF: no residual training\n");
@@ -1385,11 +1467,19 @@ size_t InvertedListScanner::iterate_codes(
1385
1467
  size_t nup = 0;
1386
1468
  list_size = 0;
1387
1469
 
1470
+ const bool has_cb = it->has_search_callbacks_;
1471
+
1388
1472
  if (!keep_max) {
1389
1473
  for (; it->is_available(); it->next()) {
1390
1474
  auto id_and_codes = it->get_id_and_codes();
1391
1475
  float dis = distance_to_code(id_and_codes.second);
1476
+ if (has_cb) {
1477
+ it->on_distance_computed(id_and_codes.first, dis);
1478
+ }
1392
1479
  if (dis < simi[0]) {
1480
+ if (has_cb) {
1481
+ it->on_heap_changed(id_and_codes.first, idxi[0]);
1482
+ }
1393
1483
  maxheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
1394
1484
  nup++;
1395
1485
  }
@@ -1399,7 +1489,13 @@ size_t InvertedListScanner::iterate_codes(
1399
1489
  for (; it->is_available(); it->next()) {
1400
1490
  auto id_and_codes = it->get_id_and_codes();
1401
1491
  float dis = distance_to_code(id_and_codes.second);
1492
+ if (has_cb) {
1493
+ it->on_distance_computed(id_and_codes.first, dis);
1494
+ }
1402
1495
  if (dis > simi[0]) {
1496
+ if (has_cb) {
1497
+ it->on_heap_changed(id_and_codes.first, idxi[0]);
1498
+ }
1403
1499
  minheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
1404
1500
  nup++;
1405
1501
  }
@@ -1419,10 +1515,14 @@ void InvertedListScanner::scan_codes_range(
1419
1515
  using C = CMax<float, idx_t>;
1420
1516
  RangeResultHandler<C, false> handler(&res, radius);
1421
1517
  scan_codes(list_size, codes, ids, handler);
1518
+ res.stats.scan_cnt += handler.stats.scan_cnt;
1519
+ res.stats.nheap_updates += handler.stats.nheap_updates;
1422
1520
  } else {
1423
1521
  using C = CMin<float, idx_t>;
1424
1522
  RangeResultHandler<C, false> handler(&res, radius);
1425
1523
  scan_codes(list_size, codes, ids, handler);
1524
+ res.stats.scan_cnt += handler.stats.scan_cnt;
1525
+ res.stats.nheap_updates += handler.stats.nheap_updates;
1426
1526
  }
1427
1527
  }
1428
1528
 
@@ -1431,6 +1531,7 @@ void InvertedListScanner::iterate_codes_range(
1431
1531
  float radius,
1432
1532
  RangeQueryResult& res,
1433
1533
  size_t& list_size) const {
1534
+ size_t nup = 0;
1434
1535
  list_size = 0;
1435
1536
  for (; it->is_available(); it->next()) {
1436
1537
  auto id_and_codes = it->get_id_and_codes();
@@ -1440,9 +1541,11 @@ void InvertedListScanner::iterate_codes_range(
1440
1541
  : dis > radius; // TODO templatize to remove this test
1441
1542
  if (keep) {
1442
1543
  res.add(dis, id_and_codes.first);
1544
+ nup++;
1443
1545
  }
1444
1546
  list_size++;
1445
1547
  }
1548
+ res.stats.nheap_updates += nup;
1446
1549
  }
1447
1550
 
1448
1551
  } // namespace faiss