faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -12,6 +12,7 @@
12
12
  #include <faiss/impl/AuxIndexStructures.h>
13
13
  #include <faiss/impl/FaissAssert.h>
14
14
  #include <faiss/impl/ResultHandler.h>
15
+ #include <faiss/impl/simd_dispatch.h>
15
16
  #include <faiss/utils/Heap.h>
16
17
  #include <faiss/utils/distances.h>
17
18
  #include <faiss/utils/extra_distances.h>
@@ -19,12 +20,11 @@
19
20
  #include <faiss/utils/sorting.h>
20
21
  #include <omp.h>
21
22
  #include <cstring>
22
- #include <numeric>
23
23
 
24
24
  namespace faiss {
25
25
 
26
- IndexFlat::IndexFlat(idx_t d, MetricType metric)
27
- : IndexFlatCodes(sizeof(float) * d, d, metric) {}
26
+ IndexFlat::IndexFlat(idx_t d_, MetricType metric)
27
+ : IndexFlatCodes(sizeof(float) * d_, d_, metric) {}
28
28
 
29
29
  void IndexFlat::search(
30
30
  idx_t n,
@@ -44,7 +44,6 @@ void IndexFlat::search(
44
44
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
45
  knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
46
46
  } else {
47
- FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
48
47
  knn_extra_metrics(
49
48
  x,
50
49
  get_xb(),
@@ -55,7 +54,8 @@ void IndexFlat::search(
55
54
  metric_arg,
56
55
  k,
57
56
  distances,
58
- labels);
57
+ labels,
58
+ sel);
59
59
  }
60
60
  }
61
61
 
@@ -65,6 +65,7 @@ void IndexFlat::range_search(
65
65
  float radius,
66
66
  RangeSearchResult* result,
67
67
  const SearchParameters* params) const {
68
+ FAISS_THROW_IF_NOT_MSG(result, "RangeSearchResult object must not be null");
68
69
  IDSelector* sel = params ? params->sel : nullptr;
69
70
 
70
71
  switch (metric_type) {
@@ -86,6 +87,7 @@ void IndexFlat::compute_distance_subset(
86
87
  idx_t k,
87
88
  float* distances,
88
89
  const idx_t* labels) const {
90
+ FAISS_THROW_IF_NOT(k > 0);
89
91
  switch (metric_type) {
90
92
  case METRIC_INNER_PRODUCT:
91
93
  fvec_inner_products_by_idx(distances, x, get_xb(), labels, d, n, k);
@@ -100,6 +102,7 @@ void IndexFlat::compute_distance_subset(
100
102
 
101
103
  namespace {
102
104
 
105
+ template <SIMDLevel SL>
103
106
  struct FlatL2Dis : FlatCodesDistanceComputer {
104
107
  size_t d;
105
108
  idx_t nb;
@@ -109,7 +112,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
109
112
 
110
113
  float distance_to_code(const uint8_t* code) final {
111
114
  ndis++;
112
- return fvec_L2sqr(q, (float*)code, d);
115
+ return fvec_L2sqr<SL>(q, (float*)code, d);
113
116
  }
114
117
 
115
118
  float partial_dot_product(
@@ -117,19 +120,19 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
117
120
  const uint32_t offset,
118
121
  const uint32_t num_components) final override {
119
122
  npartial_dot_products++;
120
- return fvec_inner_product(
123
+ return fvec_inner_product<SL>(
121
124
  q + offset, b + i * d + offset, num_components);
122
125
  }
123
126
 
124
127
  float symmetric_dis(idx_t i, idx_t j) override {
125
- return fvec_L2sqr(b + j * d, b + i * d, d);
128
+ return fvec_L2sqr<SL>(b + j * d, b + i * d, d);
126
129
  }
127
130
 
128
- explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
131
+ explicit FlatL2Dis(const IndexFlat& storage, const float* q_ = nullptr)
129
132
  : FlatCodesDistanceComputer(
130
133
  storage.codes.data(),
131
134
  storage.code_size,
132
- q),
135
+ q_),
133
136
  d(storage.d),
134
137
  nb(storage.ntotal),
135
138
  b(storage.get_xb()),
@@ -166,7 +169,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
166
169
  float dp1 = 0;
167
170
  float dp2 = 0;
168
171
  float dp3 = 0;
169
- fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
172
+ fvec_L2sqr_batch_4<SL>(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
170
173
  dis0 = dp0;
171
174
  dis1 = dp1;
172
175
  dis2 = dp2;
@@ -200,7 +203,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
200
203
  float dp1_ = 0;
201
204
  float dp2_ = 0;
202
205
  float dp3_ = 0;
203
- fvec_inner_product_batch_4(
206
+ fvec_inner_product_batch_4<SL>(
204
207
  q + offset,
205
208
  y0 + offset,
206
209
  y1 + offset,
@@ -218,6 +221,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
218
221
  }
219
222
  };
220
223
 
224
+ template <SIMDLevel SL>
221
225
  struct FlatIPDis : FlatCodesDistanceComputer {
222
226
  size_t d;
223
227
  idx_t nb;
@@ -226,21 +230,21 @@ struct FlatIPDis : FlatCodesDistanceComputer {
226
230
  size_t ndis;
227
231
 
228
232
  float symmetric_dis(idx_t i, idx_t j) final override {
229
- return fvec_inner_product(b + j * d, b + i * d, d);
233
+ return fvec_inner_product<SL>(b + j * d, b + i * d, d);
230
234
  }
231
235
 
232
236
  float distance_to_code(const uint8_t* code) final override {
233
237
  ndis++;
234
- return fvec_inner_product(q, (const float*)code, d);
238
+ return fvec_inner_product<SL>(q, (const float*)code, d);
235
239
  }
236
240
 
237
- explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
241
+ explicit FlatIPDis(const IndexFlat& storage, const float* q_in = nullptr)
238
242
  : FlatCodesDistanceComputer(
239
243
  storage.codes.data(),
240
244
  storage.code_size),
241
245
  d(storage.d),
242
246
  nb(storage.ntotal),
243
- q(q),
247
+ q(q_in),
244
248
  b(storage.get_xb()),
245
249
  ndis(0) {}
246
250
 
@@ -274,7 +278,8 @@ struct FlatIPDis : FlatCodesDistanceComputer {
274
278
  float dp1 = 0;
275
279
  float dp2 = 0;
276
280
  float dp3 = 0;
277
- fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
281
+ fvec_inner_product_batch_4<SL>(
282
+ q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
278
283
  dis0 = dp0;
279
284
  dis1 = dp1;
280
285
  dis2 = dp2;
@@ -285,14 +290,16 @@ struct FlatIPDis : FlatCodesDistanceComputer {
285
290
  } // namespace
286
291
 
287
292
  FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
293
+ FlatCodesDistanceComputer* dc = nullptr;
288
294
  if (metric_type == METRIC_L2) {
289
- return new FlatL2Dis(*this);
295
+ with_simd_level([&]<SIMDLevel SL>() { dc = new FlatL2Dis<SL>(*this); });
290
296
  } else if (metric_type == METRIC_INNER_PRODUCT) {
291
- return new FlatIPDis(*this);
297
+ with_simd_level([&]<SIMDLevel SL>() { dc = new FlatIPDis<SL>(*this); });
292
298
  } else {
293
- return get_extra_distance_computer(
299
+ dc = get_extra_distance_computer(
294
300
  d, metric_type, metric_arg, ntotal, get_xb());
295
301
  }
302
+ return dc;
296
303
  }
297
304
 
298
305
  void IndexFlat::reconstruct(idx_t key, float* recons) const {
@@ -317,6 +324,7 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
317
324
  ***************************************************/
318
325
 
319
326
  namespace {
327
+ template <SIMDLevel SL>
320
328
  struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
321
329
  size_t d;
322
330
  idx_t nb;
@@ -329,7 +337,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
329
337
 
330
338
  float distance_to_code(const uint8_t* code) final override {
331
339
  ndis++;
332
- return fvec_L2sqr(q, (float*)code, d);
340
+ return fvec_L2sqr<SL>(q, (float*)code, d);
333
341
  }
334
342
 
335
343
  float operator()(const idx_t i) final override {
@@ -337,7 +345,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
337
345
  reinterpret_cast<const float*>(codes + i * code_size);
338
346
 
339
347
  prefetch_L2(l2norms + i);
340
- const float dp0 = fvec_inner_product(q, y, d);
348
+ const float dp0 = fvec_inner_product<SL>(q, y, d);
341
349
  return query_l2norm + l2norms[i] - 2 * dp0;
342
350
  }
343
351
 
@@ -349,19 +357,19 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
349
357
 
350
358
  prefetch_L2(l2norms + i);
351
359
  prefetch_L2(l2norms + j);
352
- const float dp0 = fvec_inner_product(yi, yj, d);
360
+ const float dp0 = fvec_inner_product<SL>(yi, yj, d);
353
361
  return l2norms[i] + l2norms[j] - 2 * dp0;
354
362
  }
355
363
 
356
364
  explicit FlatL2WithNormsDis(
357
365
  const IndexFlatL2& storage,
358
- const float* q = nullptr)
366
+ const float* q_in = nullptr)
359
367
  : FlatCodesDistanceComputer(
360
368
  storage.codes.data(),
361
369
  storage.code_size),
362
370
  d(storage.d),
363
371
  nb(storage.ntotal),
364
- q(q),
372
+ q(q_in),
365
373
  b(storage.get_xb()),
366
374
  ndis(0),
367
375
  l2norms(storage.cached_l2norms.data()),
@@ -369,7 +377,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
369
377
 
370
378
  void set_query(const float* x) override {
371
379
  q = x;
372
- query_l2norm = fvec_norm_L2sqr(q, d);
380
+ query_l2norm = fvec_norm_L2sqr<SL>(q, d);
373
381
  }
374
382
 
375
383
  // compute four distances
@@ -403,7 +411,8 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
403
411
  float dp1 = 0;
404
412
  float dp2 = 0;
405
413
  float dp3 = 0;
406
- fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
414
+ fvec_inner_product_batch_4<SL>(
415
+ q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
407
416
  dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
408
417
  dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
409
418
  dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
@@ -430,7 +439,11 @@ void IndexFlatL2::clear_l2norms() {
430
439
  FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
431
440
  if (metric_type == METRIC_L2) {
432
441
  if (!cached_l2norms.empty()) {
433
- return new FlatL2WithNormsDis(*this);
442
+ FlatCodesDistanceComputer* dc = nullptr;
443
+ with_simd_level([&]<SIMDLevel SL>() {
444
+ dc = new FlatL2WithNormsDis<SL>(*this);
445
+ });
446
+ return dc;
434
447
  }
435
448
  }
436
449
 
@@ -441,8 +454,8 @@ FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
441
454
  * IndexFlat1D
442
455
  ***************************************************/
443
456
 
444
- IndexFlat1D::IndexFlat1D(bool continuous_update)
445
- : IndexFlatL2(1), continuous_update(continuous_update) {}
457
+ IndexFlat1D::IndexFlat1D(bool continuous_update_in)
458
+ : IndexFlatL2(1), continuous_update(continuous_update_in) {}
446
459
 
447
460
  /// if not continuous_update, call this between the last add and
448
461
  /// the first search
@@ -478,7 +491,8 @@ void IndexFlat1D::search(
478
491
  !params, "search params not supported for this index");
479
492
  FAISS_THROW_IF_NOT(k > 0);
480
493
  FAISS_THROW_IF_NOT_MSG(
481
- perm.size() == ntotal, "Call update_permutation before search");
494
+ perm.size() == static_cast<size_t>(ntotal),
495
+ "Call update_permutation before search");
482
496
  const float* xb = get_xb();
483
497
 
484
498
  #pragma omp parallel for if (n > 10000)
@@ -583,7 +597,17 @@ void IndexFlat1D::search(
583
597
 
584
598
  namespace {
585
599
 
586
- template <bool use_radius, typename BlockHandler>
600
+ template <typename Fn>
601
+ inline auto dispatch_metric_compare(MetricType metric, Fn&& fn) {
602
+ if (is_similarity_metric(metric)) {
603
+ using C = CMin<float, int64_t>;
604
+ return fn.template operator()<C>();
605
+ }
606
+ using C = CMax<float, int64_t>;
607
+ return fn.template operator()<C>();
608
+ }
609
+
610
+ template <bool use_radius, typename C, typename BlockHandler>
587
611
  inline void flat_pano_search_core(
588
612
  const IndexFlatPanorama& index,
589
613
  BlockHandler& handler,
@@ -603,9 +627,11 @@ inline void flat_pano_search_core(
603
627
  {
604
628
  SingleResultHandler res(handler);
605
629
 
606
- std::vector<float> query_cum_norms(index.n_levels + 1);
607
- std::vector<float> exact_distances(index.batch_size);
630
+ std::vector<float> query_cum_norms(index.pano.n_levels + 1);
608
631
  std::vector<uint32_t> active_indices(index.batch_size);
632
+ std::vector<uint8_t> active_byteset(index.batch_size);
633
+ std::vector<float> exact_distances(index.batch_size);
634
+ std::vector<float> dot_buffer(index.batch_size);
609
635
 
610
636
  #pragma omp for
611
637
  for (int64_t i = 0; i < n; i++) {
@@ -627,22 +653,25 @@ inline void flat_pano_search_core(
627
653
  threshold = res.heap_dis[0];
628
654
  }
629
655
 
630
- size_t num_active =
631
- index.pano
632
- .progressive_filter_batch<CMax<float, int64_t>>(
633
- index.codes.data(),
634
- index.cum_sums.data(),
635
- xi,
636
- query_cum_norms.data(),
637
- batch_no,
638
- index.ntotal,
639
- sel,
640
- nullptr,
641
- use_sel,
642
- active_indices,
643
- exact_distances,
644
- threshold,
645
- local_stats);
656
+ size_t num_active = with_metric_type(
657
+ index.metric_type, [&]<MetricType M>() {
658
+ return index.pano.progressive_filter_batch<C, M>(
659
+ index.codes.data(),
660
+ index.cum_sums.data(),
661
+ xi,
662
+ query_cum_norms.data(),
663
+ batch_no,
664
+ index.ntotal,
665
+ sel,
666
+ nullptr,
667
+ use_sel,
668
+ active_indices,
669
+ active_byteset,
670
+ exact_distances,
671
+ dot_buffer,
672
+ threshold,
673
+ local_stats);
674
+ });
646
675
 
647
676
  for (size_t j = 0; j < num_active; j++) {
648
677
  res.add_result(
@@ -669,7 +698,7 @@ void IndexFlatPanorama::add(idx_t n, const float* x) {
669
698
  size_t num_batches = (ntotal + batch_size - 1) / batch_size;
670
699
 
671
700
  codes.resize(num_batches * batch_size * code_size);
672
- cum_sums.resize(num_batches * batch_size * (n_levels + 1));
701
+ cum_sums.resize(num_batches * batch_size * (pano.n_levels + 1));
673
702
 
674
703
  const uint8_t* code = reinterpret_cast<const uint8_t*>(x);
675
704
  pano.copy_codes_to_level_layout(codes.data(), offset, n, code);
@@ -684,12 +713,13 @@ void IndexFlatPanorama::search(
684
713
  idx_t* labels,
685
714
  const SearchParameters* params) const {
686
715
  FAISS_THROW_IF_NOT(k > 0);
687
- FAISS_THROW_IF_NOT(batch_size >= k);
716
+ FAISS_THROW_IF_NOT(batch_size >= static_cast<size_t>(k));
688
717
 
689
- HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
690
- size_t(n), distances, labels, size_t(k), nullptr);
691
-
692
- flat_pano_search_core<false>(*this, handler, n, x, 0.0f, params);
718
+ dispatch_metric_compare(metric_type, [&]<typename C>() {
719
+ HeapBlockResultHandler<C, false> handler(
720
+ size_t(n), distances, labels, size_t(k), nullptr);
721
+ flat_pano_search_core<false, C>(*this, handler, n, x, 0.0f, params);
722
+ });
693
723
  }
694
724
 
695
725
  void IndexFlatPanorama::range_search(
@@ -698,10 +728,11 @@ void IndexFlatPanorama::range_search(
698
728
  float radius,
699
729
  RangeSearchResult* result,
700
730
  const SearchParameters* params) const {
701
- RangeSearchBlockResultHandler<CMax<float, int64_t>, false> handler(
702
- result, radius, nullptr);
703
-
704
- flat_pano_search_core<true>(*this, handler, n, x, radius, params);
731
+ dispatch_metric_compare(metric_type, [&]<typename C>() {
732
+ RangeSearchBlockResultHandler<C, false> handler(
733
+ result, radius, nullptr);
734
+ flat_pano_search_core<true, C>(*this, handler, n, x, radius, params);
735
+ });
705
736
  }
706
737
 
707
738
  void IndexFlatPanorama::reset() {
@@ -740,7 +771,7 @@ size_t IndexFlatPanorama::remove_ids(const IDSelector& sel) {
740
771
  ntotal = j;
741
772
  size_t num_batches = (ntotal + batch_size - 1) / batch_size;
742
773
  codes.resize(num_batches * batch_size * code_size);
743
- cum_sums.resize(num_batches * batch_size * (n_levels + 1));
774
+ cum_sums.resize(num_batches * batch_size * (pano.n_levels + 1));
744
775
  }
745
776
  return nremove;
746
777
  }
@@ -790,103 +821,136 @@ void IndexFlatPanorama::search_subset(
790
821
  idx_t k,
791
822
  float* distances,
792
823
  idx_t* labels) const {
793
- using SingleResultHandler =
794
- HeapBlockResultHandler<CMax<float, int64_t>, false>::
795
- SingleResultHandler;
796
- HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
797
- size_t(n), distances, labels, size_t(k), nullptr);
798
-
799
- FAISS_THROW_IF_NOT(k > 0);
800
- FAISS_THROW_IF_NOT(batch_size == 1);
801
-
802
- [[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
824
+ with_simd_level([&]<SIMDLevel SL>() {
825
+ with_metric_type(metric_type, [&]<MetricType M>() {
826
+ constexpr bool is_sim = is_similarity_metric(M);
827
+ using C = std::conditional_t<
828
+ is_sim,
829
+ CMin<float, int64_t>,
830
+ CMax<float, int64_t>>;
831
+ using SingleResultHandler =
832
+ typename HeapBlockResultHandler<C, false>::
833
+ SingleResultHandler;
834
+ HeapBlockResultHandler<C, false> handler(
835
+ size_t(n), distances, labels, size_t(k), nullptr);
836
+
837
+ FAISS_THROW_IF_NOT(k > 0);
838
+ FAISS_THROW_IF_NOT(batch_size == 1);
839
+
840
+ [[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
803
841
 
804
842
  #pragma omp parallel num_threads(nt)
805
- {
806
- SingleResultHandler res(handler);
807
-
808
- std::vector<float> query_cum_norms(n_levels + 1);
809
-
810
- // Panorama's optimized point-wise refinement (Algorithm 2):
811
- // Batch-wise Panorama, as implemented in Panorama.h, incurs overhead
812
- // from maintaining active_indices and exact_distances. This optimized
813
- // implementation has minimal overhead and is thus preferred for
814
- // IndexRefine's use case.
815
- // 1. Initialize exact distance as ||y||^2 + ||x||^2.
816
- // 2. For each level, refine distance incrementally:
817
- // - Compute dot product for current level: exact_dist -= 2*<x,y>.
818
- // - Use Cauchy-Schwarz bound on remaining levels to get lower bound.
819
- // - If there are less than k points in the heap, add the point to
820
- // the heap.
821
- // - Else, prune if lower bound exceeds k-th best distance.
822
- // 3. After all levels, update heap if the point survived.
843
+ {
844
+ SingleResultHandler res(handler);
845
+
846
+ std::vector<float> query_cum_norms(pano.n_levels + 1);
847
+
848
+ // Panorama's optimized point-wise refinement (Algorithm 2):
849
+ // Batch-wise Panorama, as implemented in Panorama.h, incurs
850
+ // overhead from maintaining active_indices and exact_distances.
851
+ // This optimized implementation has minimal overhead and is
852
+ // thus preferred for IndexRefine's use case.
853
+ // 1. Initialize exact distance as ||y||^2 + ||x||^2.
854
+ // 2. For each level, refine distance incrementally:
855
+ // - Compute dot product for current level: exact_dist -=
856
+ // 2*<x,y>.
857
+ // - Use Cauchy-Schwarz bound on remaining levels to get
858
+ // lower bound.
859
+ // - If there are less than k points in the heap, add the
860
+ // point to the heap.
861
+ // - Else, prune if lower bound exceeds k-th best distance.
862
+ // 3. After all levels, update heap if the point survived.
823
863
  #pragma omp for
824
- for (idx_t i = 0; i < n; i++) {
825
- const idx_t* __restrict idsi = base_labels + i * k_base;
826
- const float* xi = x + i * d;
827
-
828
- PanoramaStats local_stats;
829
- local_stats.reset();
830
-
831
- pano.compute_query_cum_sums(xi, query_cum_norms.data());
832
- float query_cum_norm = query_cum_norms[0] * query_cum_norms[0];
833
-
834
- res.begin(i);
835
-
836
- for (size_t j = 0; j < k_base; j++) {
837
- idx_t idx = idsi[j];
838
-
839
- if (idx < 0) {
840
- continue;
841
- }
842
-
843
- size_t cum_sum_offset = (n_levels + 1) * idx;
844
- float cum_sum = cum_sums[cum_sum_offset];
845
- float exact_distance = cum_sum * cum_sum + query_cum_norm;
846
- cum_sum_offset++;
847
-
848
- const float* x_ptr = xi;
849
- const float* p_ptr =
850
- reinterpret_cast<const float*>(codes.data()) + d * idx;
851
-
852
- local_stats.total_dims += d;
853
-
854
- bool pruned = false;
855
- for (size_t level = 0; level < n_levels; level++) {
856
- local_stats.total_dims_scanned += pano.level_width_floats;
857
-
858
- // Refine distance
859
- size_t actual_level_width = std::min(
860
- pano.level_width_floats,
861
- d - level * pano.level_width_floats);
862
- float dot_product = fvec_inner_product(
863
- x_ptr, p_ptr, actual_level_width);
864
- exact_distance -= 2 * dot_product;
865
-
866
- float cum_sum = cum_sums[cum_sum_offset];
867
- float cauchy_schwarz_bound =
868
- 2.0f * cum_sum * query_cum_norms[level + 1];
869
- float lower_bound = exact_distance - cauchy_schwarz_bound;
870
-
871
- // Prune using Cauchy-Schwarz bound
872
- if (lower_bound > res.heap_dis[0]) {
873
- pruned = true;
874
- break;
864
+ for (idx_t i = 0; i < n; i++) {
865
+ const idx_t* __restrict idsi = base_labels + i * k_base;
866
+ const float* xi = x + i * d;
867
+
868
+ PanoramaStats local_stats;
869
+ local_stats.reset();
870
+
871
+ pano.compute_query_cum_sums(xi, query_cum_norms.data());
872
+ float query_cum_norm =
873
+ query_cum_norms[0] * query_cum_norms[0];
874
+
875
+ res.begin(i);
876
+
877
+ for (idx_t j = 0; j < k_base; j++) {
878
+ idx_t idx = idsi[j];
879
+
880
+ if (idx < 0) {
881
+ continue;
882
+ }
883
+
884
+ size_t cum_sum_offset = (pano.n_levels + 1) * idx;
885
+ float cum_sum = cum_sums[cum_sum_offset];
886
+ float exact_distance = 0.0f;
887
+ if constexpr (!is_sim) {
888
+ exact_distance = cum_sum * cum_sum + query_cum_norm;
889
+ }
890
+ cum_sum_offset++;
891
+
892
+ const float* x_ptr = xi;
893
+ const float* p_ptr =
894
+ reinterpret_cast<const float*>(codes.data()) +
895
+ d * idx;
896
+
897
+ local_stats.total_dims += d;
898
+
899
+ bool pruned = false;
900
+ for (size_t level = 0; level < pano.n_levels; level++) {
901
+ local_stats.total_dims_scanned +=
902
+ pano.level_width_floats;
903
+
904
+ // Refine distance
905
+ size_t actual_level_width = std::min(
906
+ pano.level_width_floats,
907
+ d - level * pano.level_width_floats);
908
+ float dot_product = fvec_inner_product<SL>(
909
+ x_ptr, p_ptr, actual_level_width);
910
+ if constexpr (is_sim) {
911
+ exact_distance += dot_product;
912
+ } else {
913
+ exact_distance -= 2 * dot_product;
914
+ }
915
+
916
+ float level_cum_sum = cum_sums[cum_sum_offset];
917
+ float cauchy_schwarz_bound;
918
+ if constexpr (is_sim) {
919
+ cauchy_schwarz_bound = -level_cum_sum *
920
+ query_cum_norms[level + 1];
921
+ } else {
922
+ cauchy_schwarz_bound = 2.0f * level_cum_sum *
923
+ query_cum_norms[level + 1];
924
+ }
925
+ float bound = exact_distance - cauchy_schwarz_bound;
926
+
927
+ // Prune using Cauchy-Schwarz bound
928
+ bool should_prune = false;
929
+ if constexpr (is_sim) {
930
+ should_prune = bound < res.heap_dis[0];
931
+ } else {
932
+ should_prune = bound > res.heap_dis[0];
933
+ }
934
+ if (should_prune) {
935
+ pruned = true;
936
+ break;
937
+ }
938
+
939
+ cum_sum_offset++;
940
+ x_ptr += pano.level_width_floats;
941
+ p_ptr += pano.level_width_floats;
942
+ }
943
+
944
+ if (!pruned) {
945
+ res.add_result(exact_distance, idx);
946
+ }
875
947
  }
876
948
 
877
- cum_sum_offset++;
878
- x_ptr += pano.level_width_floats;
879
- p_ptr += pano.level_width_floats;
880
- }
881
-
882
- if (!pruned) {
883
- res.add_result(exact_distance, idx);
949
+ res.end();
950
+ indexPanorama_stats.add(local_stats);
884
951
  }
885
952
  }
886
-
887
- res.end();
888
- indexPanorama_stats.add(local_stats);
889
- }
890
- }
953
+ });
954
+ });
891
955
  }
892
956
  } // namespace faiss