faiss 0.5.3 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +4 -4
- data/ext/faiss/index.cpp +63 -45
- data/ext/faiss/index_binary.cpp +37 -27
- data/ext/faiss/kmeans.cpp +9 -8
- data/ext/faiss/pca_matrix.cpp +9 -7
- data/ext/faiss/product_quantizer.cpp +13 -11
- data/ext/faiss/utils.cpp +4 -2
- data/ext/faiss/utils.h +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +214 -82
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +97 -249
- data/vendor/faiss/faiss/Clustering.h +18 -0
- data/vendor/faiss/faiss/IVFlib.cpp +67 -44
- data/vendor/faiss/faiss/Index.cpp +25 -12
- data/vendor/faiss/faiss/Index.h +26 -4
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
- data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
- data/vendor/faiss/faiss/IndexFastScan.h +35 -24
- data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
- data/vendor/faiss/faiss/IndexFlat.h +32 -14
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
- data/vendor/faiss/faiss/IndexHNSW.h +30 -14
- data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
- data/vendor/faiss/faiss/IndexIVF.h +47 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
- data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
- data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
- data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
- data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
- data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +13 -13
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +29 -6
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
- data/vendor/faiss/faiss/VectorTransform.h +39 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +55 -51
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
- data/vendor/faiss/faiss/impl/HNSW.h +21 -40
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
- data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
- data/vendor/faiss/faiss/impl/NSG.h +20 -10
- data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
- data/vendor/faiss/faiss/impl/Panorama.h +265 -78
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
- data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +115 -28
- data/vendor/faiss/faiss/index_io.h +53 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +507 -559
- data/vendor/faiss/faiss/utils/distances.h +118 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
- data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
- data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +21 -14
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +156 -42
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -8,14 +8,17 @@
|
|
|
8
8
|
#include <faiss/IndexIVFRaBitQFastScan.h>
|
|
9
9
|
|
|
10
10
|
#include <algorithm>
|
|
11
|
+
#include <array>
|
|
11
12
|
#include <cstdio>
|
|
13
|
+
#include <memory>
|
|
12
14
|
|
|
15
|
+
#include <faiss/impl/CodePackerRaBitQ.h>
|
|
13
16
|
#include <faiss/impl/FaissAssert.h>
|
|
14
|
-
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
15
17
|
#include <faiss/impl/RaBitQUtils.h>
|
|
16
18
|
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
17
|
-
#include <faiss/impl/
|
|
18
|
-
#include <faiss/impl/
|
|
19
|
+
#include <faiss/impl/ResultHandler.h>
|
|
20
|
+
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
|
|
21
|
+
#include <faiss/impl/fast_scan/fast_scan.h>
|
|
19
22
|
#include <faiss/invlists/BlockInvertedLists.h>
|
|
20
23
|
#include <faiss/utils/distances.h>
|
|
21
24
|
#include <faiss/utils/utils.h>
|
|
@@ -39,31 +42,38 @@ inline size_t roundup(size_t a, size_t b) {
|
|
|
39
42
|
IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan() = default;
|
|
40
43
|
|
|
41
44
|
IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
|
|
42
|
-
Index*
|
|
43
|
-
size_t
|
|
44
|
-
size_t
|
|
45
|
+
Index* quantizer_in,
|
|
46
|
+
size_t d_in,
|
|
47
|
+
size_t nlist_in,
|
|
45
48
|
MetricType metric,
|
|
46
|
-
int
|
|
47
|
-
bool
|
|
49
|
+
int bbs_in,
|
|
50
|
+
bool own_invlists_in,
|
|
48
51
|
uint8_t nb_bits)
|
|
49
|
-
: IndexIVFFastScan(
|
|
50
|
-
|
|
51
|
-
|
|
52
|
+
: IndexIVFFastScan(
|
|
53
|
+
quantizer_in,
|
|
54
|
+
d_in,
|
|
55
|
+
nlist_in,
|
|
56
|
+
0,
|
|
57
|
+
metric,
|
|
58
|
+
own_invlists_in),
|
|
59
|
+
rabitq(d_in, metric, nb_bits) {
|
|
60
|
+
FAISS_THROW_IF_NOT_MSG(d_in > 0, "Dimension must be positive");
|
|
52
61
|
FAISS_THROW_IF_NOT_MSG(
|
|
53
62
|
metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
|
|
54
63
|
"RaBitQ only supports L2 and Inner Product metrics");
|
|
55
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
56
|
-
|
|
64
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
65
|
+
bbs_in % 32 == 0, "Batch size must be multiple of 32");
|
|
66
|
+
FAISS_THROW_IF_NOT_MSG(quantizer_in != nullptr, "Quantizer cannot be null");
|
|
57
67
|
|
|
58
68
|
by_residual = true;
|
|
59
69
|
qb = 8; // RaBitQ quantization bits
|
|
60
70
|
centered = false;
|
|
61
71
|
|
|
62
72
|
// FastScan-specific parameters: 4 bits per sub-quantizer
|
|
63
|
-
const size_t M_fastscan = (
|
|
73
|
+
const size_t M_fastscan = (d_in + 3) / 4;
|
|
64
74
|
constexpr size_t nbits_fastscan = 4;
|
|
65
75
|
|
|
66
|
-
this->bbs =
|
|
76
|
+
this->bbs = bbs_in;
|
|
67
77
|
this->fine_quantizer = &rabitq;
|
|
68
78
|
this->M = M_fastscan;
|
|
69
79
|
this->nbits = nbits_fastscan;
|
|
@@ -79,8 +89,6 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
|
|
|
79
89
|
if (own_invlists) {
|
|
80
90
|
replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
|
|
81
91
|
}
|
|
82
|
-
|
|
83
|
-
flat_storage.clear();
|
|
84
92
|
}
|
|
85
93
|
|
|
86
94
|
// Constructor that converts an existing IndexIVFRaBitQ to FastScan format
|
|
@@ -97,35 +105,11 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
|
|
|
97
105
|
rabitq(orig.rabitq) {}
|
|
98
106
|
|
|
99
107
|
size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
if (ex_bits == 0) {
|
|
103
|
-
// 1-bit: only SignBitFactors (8 bytes)
|
|
104
|
-
return sizeof(SignBitFactors);
|
|
105
|
-
} else {
|
|
106
|
-
// Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + ex-codes
|
|
107
|
-
return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
|
|
108
|
-
(d * ex_bits + 7) / 8;
|
|
109
|
-
}
|
|
108
|
+
return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
|
|
110
109
|
}
|
|
111
110
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
const uint8_t* flat_codes,
|
|
115
|
-
idx_t start_global_idx) {
|
|
116
|
-
// Unified approach: always use flat_storage for both 1-bit and multi-bit
|
|
117
|
-
const size_t storage_size = compute_per_vector_storage_size();
|
|
118
|
-
flat_storage.resize((start_global_idx + n) * storage_size);
|
|
119
|
-
|
|
120
|
-
// Copy factors data directly to flat storage (no reordering needed)
|
|
121
|
-
const size_t bit_pattern_size = (d + 7) / 8;
|
|
122
|
-
for (idx_t i = 0; i < n; i++) {
|
|
123
|
-
const uint8_t* code = flat_codes + i * code_size;
|
|
124
|
-
const uint8_t* source_factors_ptr = code + bit_pattern_size;
|
|
125
|
-
uint8_t* storage =
|
|
126
|
-
flat_storage.data() + (start_global_idx + i) * storage_size;
|
|
127
|
-
memcpy(storage, source_factors_ptr, storage_size);
|
|
128
|
-
}
|
|
111
|
+
size_t IndexIVFRaBitQFastScan::fast_scan_code_size() const {
|
|
112
|
+
return (d + 7) / 8;
|
|
129
113
|
}
|
|
130
114
|
|
|
131
115
|
size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
|
|
@@ -133,6 +117,45 @@ size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
|
|
|
133
117
|
return code_size;
|
|
134
118
|
}
|
|
135
119
|
|
|
120
|
+
CodePacker* IndexIVFRaBitQFastScan::get_CodePacker() const {
|
|
121
|
+
return new CodePackerRaBitQ(M2, bbs, compute_per_vector_storage_size());
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
/*********************************************************
|
|
125
|
+
* postprocess_packed_codes: write auxiliary data into blocks
|
|
126
|
+
*********************************************************/
|
|
127
|
+
|
|
128
|
+
void IndexIVFRaBitQFastScan::postprocess_packed_codes(
|
|
129
|
+
idx_t list_no,
|
|
130
|
+
size_t list_offset,
|
|
131
|
+
size_t n_added,
|
|
132
|
+
const uint8_t* flat_codes) {
|
|
133
|
+
auto* bil = dynamic_cast<BlockInvertedLists*>(invlists);
|
|
134
|
+
FAISS_THROW_IF_NOT(bil);
|
|
135
|
+
|
|
136
|
+
uint8_t* block_data = bil->codes[list_no].data();
|
|
137
|
+
const size_t storage_size = compute_per_vector_storage_size();
|
|
138
|
+
const size_t bit_pattern_size = (d + 7) / 8;
|
|
139
|
+
const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
|
|
140
|
+
const size_t full_block_size = get_block_stride();
|
|
141
|
+
|
|
142
|
+
for (size_t i = 0; i < n_added; i++) {
|
|
143
|
+
const uint8_t* src = flat_codes + i * code_size + bit_pattern_size;
|
|
144
|
+
uint8_t* dst = rabitq_utils::get_block_aux_ptr(
|
|
145
|
+
block_data,
|
|
146
|
+
list_offset + i,
|
|
147
|
+
bbs,
|
|
148
|
+
packed_block_size,
|
|
149
|
+
full_block_size,
|
|
150
|
+
storage_size);
|
|
151
|
+
memcpy(dst, src, storage_size);
|
|
152
|
+
}
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
/*********************************************************
|
|
156
|
+
* train_encoder
|
|
157
|
+
*********************************************************/
|
|
158
|
+
|
|
136
159
|
void IndexIVFRaBitQFastScan::train_encoder(
|
|
137
160
|
idx_t n,
|
|
138
161
|
const float* x,
|
|
@@ -183,7 +206,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
|
|
|
183
206
|
const size_t bit_pattern_size = (d + 7) / 8;
|
|
184
207
|
|
|
185
208
|
// Pack sign bits directly into FastScan format (inline)
|
|
186
|
-
for (size_t j = 0; j < d; j++) {
|
|
209
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
187
210
|
const float or_minus_c = xi[j] - centroid[j];
|
|
188
211
|
if (or_minus_c > 0.0f) {
|
|
189
212
|
rabitq_utils::set_bit_fastscan(fastscan_code, j);
|
|
@@ -212,7 +235,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
|
|
|
212
235
|
|
|
213
236
|
// Compute residual (needed for quantize_ex_bits)
|
|
214
237
|
std::vector<float> residual(d);
|
|
215
|
-
for (size_t j = 0; j < d; j++) {
|
|
238
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
216
239
|
residual[j] = xi[j] - centroid[j];
|
|
217
240
|
}
|
|
218
241
|
|
|
@@ -249,83 +272,133 @@ bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
|
|
|
249
272
|
return true;
|
|
250
273
|
}
|
|
251
274
|
|
|
275
|
+
// out[code] = base + sum of v_i for each set bit in code.
|
|
276
|
+
inline void write_subset_sum_lut(
|
|
277
|
+
float* out,
|
|
278
|
+
float base,
|
|
279
|
+
float v0,
|
|
280
|
+
float v1,
|
|
281
|
+
float v2,
|
|
282
|
+
float v3) {
|
|
283
|
+
out[0] = base;
|
|
284
|
+
out[1] = base + v0;
|
|
285
|
+
out[2] = base + v1;
|
|
286
|
+
out[3] = base + v0 + v1;
|
|
287
|
+
out[4] = base + v2;
|
|
288
|
+
out[5] = base + v0 + v2;
|
|
289
|
+
out[6] = base + v1 + v2;
|
|
290
|
+
out[7] = base + v0 + v1 + v2;
|
|
291
|
+
out[8] = base + v3;
|
|
292
|
+
out[9] = base + v0 + v3;
|
|
293
|
+
out[10] = base + v1 + v3;
|
|
294
|
+
out[11] = base + v0 + v1 + v3;
|
|
295
|
+
out[12] = base + v2 + v3;
|
|
296
|
+
out[13] = base + v0 + v2 + v3;
|
|
297
|
+
out[14] = base + v1 + v2 + v3;
|
|
298
|
+
out[15] = base + v0 + v1 + v2 + v3;
|
|
299
|
+
}
|
|
300
|
+
|
|
252
301
|
// Computes lookup table for residual vectors in RaBitQ FastScan format
|
|
253
302
|
void IndexIVFRaBitQFastScan::compute_residual_LUT(
|
|
254
|
-
const float*
|
|
303
|
+
const float* query,
|
|
304
|
+
idx_t centroid_id,
|
|
255
305
|
QueryFactorsData& query_factors,
|
|
256
306
|
float* lut_out,
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
307
|
+
uint8_t qb_param,
|
|
308
|
+
bool centered_param,
|
|
309
|
+
std::vector<float>& rotated_q,
|
|
310
|
+
std::vector<float>& centroid_buf) const {
|
|
311
|
+
const size_t d_val = static_cast<size_t>(d);
|
|
312
|
+
FAISS_THROW_IF_NOT(d_val > 0);
|
|
313
|
+
rotated_q.resize(d_val);
|
|
314
|
+
centroid_buf.resize(d_val);
|
|
315
|
+
std::vector<uint8_t> rotated_qq(d_val);
|
|
316
|
+
|
|
317
|
+
// Compute residual
|
|
318
|
+
quantizer->reconstruct(centroid_id, centroid_buf.data());
|
|
319
|
+
for (size_t i = 0; i < d_val; i++) {
|
|
320
|
+
rotated_q[i] = query[i] - centroid_buf[i];
|
|
321
|
+
}
|
|
262
322
|
|
|
263
|
-
//
|
|
323
|
+
// Compute query factors using shared utility
|
|
264
324
|
query_factors = rabitq_utils::compute_query_factors(
|
|
265
|
-
|
|
266
|
-
|
|
325
|
+
rotated_q.data(),
|
|
326
|
+
d_val,
|
|
267
327
|
nullptr,
|
|
268
|
-
|
|
269
|
-
|
|
328
|
+
qb_param,
|
|
329
|
+
centered_param,
|
|
270
330
|
metric_type,
|
|
271
331
|
rotated_q,
|
|
272
332
|
rotated_qq);
|
|
273
333
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
334
|
+
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
|
|
335
|
+
query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d_val);
|
|
336
|
+
query_factors.q_dot_c =
|
|
337
|
+
fvec_inner_product(query, centroid_buf.data(), d_val);
|
|
278
338
|
}
|
|
279
339
|
|
|
280
|
-
|
|
281
|
-
if (ex_bits > 0) {
|
|
340
|
+
if (rabitq.nb_bits > 1) {
|
|
282
341
|
query_factors.rotated_q = rotated_q;
|
|
283
342
|
}
|
|
284
343
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
for (size_t m = 0; m < M; m++) {
|
|
289
|
-
const size_t dim_start = m * 4;
|
|
290
|
-
|
|
291
|
-
for (int code_val = 0; code_val < 16; code_val++) {
|
|
292
|
-
float xor_contribution = 0.0f;
|
|
344
|
+
// Build LUT using branchless subset-sum construction
|
|
345
|
+
const size_t d_sz = d_val;
|
|
293
346
|
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
if (dim_idx < d) {
|
|
298
|
-
const bool db_bit = (code_val >> dim_offset) & 1;
|
|
299
|
-
const float query_value = rotated_qq[dim_idx];
|
|
300
|
-
|
|
301
|
-
xor_contribution += db_bit
|
|
302
|
-
? (max_code_value - query_value)
|
|
303
|
-
: query_value;
|
|
304
|
-
}
|
|
305
|
-
}
|
|
347
|
+
if (centered_param) {
|
|
348
|
+
const float mcv = static_cast<float>((1 << qb_param) - 1);
|
|
306
349
|
|
|
307
|
-
|
|
350
|
+
for (size_t m = 0; m < M; m++) {
|
|
351
|
+
const size_t ds = m * 4;
|
|
352
|
+
float* out = lut_out + m * 16;
|
|
353
|
+
|
|
354
|
+
float base = 0.0f;
|
|
355
|
+
float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
|
|
356
|
+
if (ds + 0 < d_sz) {
|
|
357
|
+
float q = rotated_qq[ds + 0];
|
|
358
|
+
base += q;
|
|
359
|
+
v0 = mcv - 2.0f * q;
|
|
360
|
+
}
|
|
361
|
+
if (ds + 1 < d_sz) {
|
|
362
|
+
float q = rotated_qq[ds + 1];
|
|
363
|
+
base += q;
|
|
364
|
+
v1 = mcv - 2.0f * q;
|
|
365
|
+
}
|
|
366
|
+
if (ds + 2 < d_sz) {
|
|
367
|
+
float q = rotated_qq[ds + 2];
|
|
368
|
+
base += q;
|
|
369
|
+
v2 = mcv - 2.0f * q;
|
|
370
|
+
}
|
|
371
|
+
if (ds + 3 < d_sz) {
|
|
372
|
+
float q = rotated_qq[ds + 3];
|
|
373
|
+
base += q;
|
|
374
|
+
v3 = mcv - 2.0f * q;
|
|
308
375
|
}
|
|
376
|
+
|
|
377
|
+
write_subset_sum_lut(out, base, v0, v1, v2, v3);
|
|
309
378
|
}
|
|
310
379
|
} else {
|
|
311
|
-
|
|
312
|
-
|
|
380
|
+
const float c1 = query_factors.c1;
|
|
381
|
+
const float c2 = query_factors.c2;
|
|
313
382
|
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
319
|
-
const size_t dim_idx = dim_start + dim_offset;
|
|
383
|
+
for (size_t m = 0; m < M; m++) {
|
|
384
|
+
const size_t ds = m * 4;
|
|
385
|
+
float* out = lut_out + m * 16;
|
|
320
386
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
387
|
+
float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
|
|
388
|
+
if (ds + 0 < d_sz) {
|
|
389
|
+
v0 = c1 * rotated_qq[ds + 0] + c2;
|
|
390
|
+
}
|
|
391
|
+
if (ds + 1 < d_sz) {
|
|
392
|
+
v1 = c1 * rotated_qq[ds + 1] + c2;
|
|
393
|
+
}
|
|
394
|
+
if (ds + 2 < d_sz) {
|
|
395
|
+
v2 = c1 * rotated_qq[ds + 2] + c2;
|
|
328
396
|
}
|
|
397
|
+
if (ds + 3 < d_sz) {
|
|
398
|
+
v3 = c1 * rotated_qq[ds + 3] + c2;
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
write_subset_sum_lut(out, 0.0f, v0, v1, v2, v3);
|
|
329
402
|
}
|
|
330
403
|
}
|
|
331
404
|
}
|
|
@@ -347,18 +420,27 @@ void IndexIVFRaBitQFastScan::search_preassigned(
|
|
|
347
420
|
!store_pairs, "store_pairs not supported for RaBitQFastScan");
|
|
348
421
|
FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
|
|
349
422
|
|
|
350
|
-
size_t
|
|
423
|
+
size_t cur_nprobe = this->nprobe;
|
|
424
|
+
uint8_t used_qb = qb;
|
|
425
|
+
bool used_centered = centered;
|
|
351
426
|
if (params) {
|
|
352
427
|
FAISS_THROW_IF_NOT(params->max_codes == 0);
|
|
353
|
-
|
|
428
|
+
cur_nprobe = params->nprobe;
|
|
429
|
+
if (auto rparams =
|
|
430
|
+
dynamic_cast<const IVFRaBitQSearchParameters*>(params)) {
|
|
431
|
+
used_qb = rparams->qb;
|
|
432
|
+
used_centered = rparams->centered;
|
|
433
|
+
}
|
|
354
434
|
}
|
|
355
435
|
|
|
356
|
-
std::vector<QueryFactorsData> query_factors_storage(n *
|
|
436
|
+
std::vector<QueryFactorsData> query_factors_storage(n * cur_nprobe);
|
|
357
437
|
FastScanDistancePostProcessing context;
|
|
358
438
|
context.query_factors = query_factors_storage.data();
|
|
359
|
-
context.nprobe =
|
|
439
|
+
context.nprobe = cur_nprobe;
|
|
440
|
+
context.qb = used_qb;
|
|
441
|
+
context.centered = used_centered;
|
|
360
442
|
|
|
361
|
-
const CoarseQuantized cq = {
|
|
443
|
+
const CoarseQuantized cq = {cur_nprobe, centroid_dis, assign};
|
|
362
444
|
search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
|
|
363
445
|
}
|
|
364
446
|
|
|
@@ -372,44 +454,165 @@ void IndexIVFRaBitQFastScan::compute_LUT(
|
|
|
372
454
|
FAISS_THROW_IF_NOT(is_trained);
|
|
373
455
|
FAISS_THROW_IF_NOT(by_residual);
|
|
374
456
|
|
|
375
|
-
|
|
457
|
+
// Use overridden qb/centered from context if provided, else index defaults
|
|
458
|
+
const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
|
|
459
|
+
const bool used_centered = context.qb > 0 ? context.centered : centered;
|
|
460
|
+
|
|
461
|
+
size_t cq_nprobe = cq.nprobe;
|
|
376
462
|
|
|
377
463
|
size_t dim12 = 16 * M;
|
|
378
464
|
|
|
379
|
-
dis_tables.resize(n *
|
|
380
|
-
biases.resize(n *
|
|
465
|
+
dis_tables.resize(n * cq_nprobe * dim12);
|
|
466
|
+
biases.resize(n * cq_nprobe);
|
|
381
467
|
|
|
382
|
-
if (n *
|
|
383
|
-
memset(biases.get(), 0, sizeof(float) * n *
|
|
468
|
+
if (n * cq_nprobe > 0) {
|
|
469
|
+
memset(biases.get(), 0, sizeof(float) * n * cq_nprobe);
|
|
384
470
|
}
|
|
385
|
-
|
|
471
|
+
// Use per-thread buffers instead of one O(n * nprobe * d) allocation.
|
|
472
|
+
// rotated_q / centroid_buf keep their capacity across iterations so the
|
|
473
|
+
// allocator is only hit once per thread.
|
|
474
|
+
#pragma omp parallel if (n * cq_nprobe > 1000)
|
|
475
|
+
{
|
|
476
|
+
std::vector<float> rotated_q(d);
|
|
477
|
+
std::vector<float> centroid_buf(d);
|
|
478
|
+
|
|
479
|
+
#pragma omp for
|
|
480
|
+
for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
|
|
481
|
+
idx_t i = ij / cq_nprobe;
|
|
482
|
+
idx_t cij = cq.ids[ij];
|
|
483
|
+
|
|
484
|
+
if (cij >= 0) {
|
|
485
|
+
QueryFactorsData query_factors_data;
|
|
486
|
+
|
|
487
|
+
compute_residual_LUT(
|
|
488
|
+
x + i * d,
|
|
489
|
+
cij,
|
|
490
|
+
query_factors_data,
|
|
491
|
+
dis_tables.get() + ij * dim12,
|
|
492
|
+
used_qb,
|
|
493
|
+
used_centered,
|
|
494
|
+
rotated_q,
|
|
495
|
+
centroid_buf);
|
|
496
|
+
|
|
497
|
+
if (context.query_factors != nullptr) {
|
|
498
|
+
context.query_factors[ij] = std::move(query_factors_data);
|
|
499
|
+
}
|
|
386
500
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
501
|
+
} else {
|
|
502
|
+
memset(dis_tables.get() + ij * dim12, 0, sizeof(float) * dim12);
|
|
503
|
+
}
|
|
504
|
+
}
|
|
505
|
+
}
|
|
506
|
+
}
|
|
392
507
|
|
|
393
|
-
|
|
394
|
-
|
|
508
|
+
void IndexIVFRaBitQFastScan::compute_LUT_uint8(
|
|
509
|
+
size_t n,
|
|
510
|
+
const float* x,
|
|
511
|
+
const CoarseQuantized& cq,
|
|
512
|
+
AlignedTable<uint8_t>& dis_tables,
|
|
513
|
+
AlignedTable<uint16_t>& biases,
|
|
514
|
+
float* normalizers,
|
|
515
|
+
const FastScanDistancePostProcessing& context) const {
|
|
516
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
517
|
+
FAISS_THROW_IF_NOT(by_residual);
|
|
395
518
|
|
|
396
|
-
|
|
397
|
-
|
|
519
|
+
const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
|
|
520
|
+
const bool used_centered = context.qb > 0 ? context.centered : centered;
|
|
521
|
+
const size_t cur_nprobe = cq.nprobe;
|
|
522
|
+
const size_t dim12 = 16 * M;
|
|
523
|
+
const size_t dim12_2 = 16 * M2;
|
|
398
524
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
dis_tables.get() + ij * dim12,
|
|
403
|
-
x + i * d);
|
|
525
|
+
// Allocate only the uint8 output table (no full float table)
|
|
526
|
+
dis_tables.resize(n * cur_nprobe * dim12_2);
|
|
527
|
+
biases.resize(n * cur_nprobe);
|
|
404
528
|
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
529
|
+
#pragma omp parallel if (n > 1)
|
|
530
|
+
{
|
|
531
|
+
// Per-thread buffers reused across queries
|
|
532
|
+
AlignedTable<float> lut_float(cur_nprobe * dim12);
|
|
533
|
+
std::vector<float> rotated_q(d);
|
|
534
|
+
std::vector<float> centroid_buf(d);
|
|
535
|
+
std::vector<float> all_mins(cur_nprobe * M);
|
|
536
|
+
std::vector<float> probe_b(cur_nprobe);
|
|
537
|
+
|
|
538
|
+
#pragma omp for schedule(dynamic)
|
|
539
|
+
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
|
|
540
|
+
const float* xi = x + i * d;
|
|
541
|
+
|
|
542
|
+
// Compute float LUT for all probes using fused path
|
|
543
|
+
for (size_t j = 0; j < cur_nprobe; j++) {
|
|
544
|
+
const size_t ij = i * cur_nprobe + j;
|
|
545
|
+
idx_t cij = cq.ids[ij];
|
|
546
|
+
|
|
547
|
+
if (cij >= 0) {
|
|
548
|
+
QueryFactorsData qf;
|
|
549
|
+
compute_residual_LUT(
|
|
550
|
+
xi,
|
|
551
|
+
cij,
|
|
552
|
+
qf,
|
|
553
|
+
lut_float.get() + j * dim12,
|
|
554
|
+
used_qb,
|
|
555
|
+
used_centered,
|
|
556
|
+
rotated_q,
|
|
557
|
+
centroid_buf);
|
|
558
|
+
|
|
559
|
+
if (context.query_factors != nullptr) {
|
|
560
|
+
context.query_factors[ij] = qf;
|
|
561
|
+
}
|
|
562
|
+
} else {
|
|
563
|
+
memset(lut_float.get() + j * dim12,
|
|
564
|
+
0,
|
|
565
|
+
sizeof(float) * dim12);
|
|
566
|
+
}
|
|
408
567
|
}
|
|
409
568
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
569
|
+
// Quantize float LUT to uint8 inline.
|
|
570
|
+
// Mirrors quantize_LUT_and_bias 3D path with zero biases.
|
|
571
|
+
// Single pass: find per-sub-q mins, max span, and per-probe b.
|
|
572
|
+
float glob_max_span = -HUGE_VAL;
|
|
573
|
+
float glob_max_dis = -HUGE_VAL;
|
|
574
|
+
float glob_b = HUGE_VAL;
|
|
575
|
+
for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
|
|
576
|
+
float b_j = 0;
|
|
577
|
+
float span_j = 0;
|
|
578
|
+
for (size_t m = 0; m < M; m++) {
|
|
579
|
+
const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
|
|
580
|
+
float mn = tab[0], mx = tab[0];
|
|
581
|
+
for (size_t s = 1; s < ksub; s++) {
|
|
582
|
+
mn = std::min(mn, tab[s]);
|
|
583
|
+
mx = std::max(mx, tab[s]);
|
|
584
|
+
}
|
|
585
|
+
all_mins[j2 * M + m] = mn;
|
|
586
|
+
float span = mx - mn;
|
|
587
|
+
glob_max_span = std::max(glob_max_span, span);
|
|
588
|
+
b_j += mn;
|
|
589
|
+
span_j += span;
|
|
590
|
+
}
|
|
591
|
+
probe_b[j2] = b_j;
|
|
592
|
+
glob_max_dis = std::max(glob_max_dis, span_j);
|
|
593
|
+
glob_b = std::min(glob_b, b_j);
|
|
594
|
+
}
|
|
595
|
+
float a = std::min(255.0f / glob_max_span, 65535.0f / glob_max_dis);
|
|
596
|
+
|
|
597
|
+
// Second pass: quantize LUT and compute biasq
|
|
598
|
+
uint8_t* out_base = dis_tables.get() + i * cur_nprobe * dim12_2;
|
|
599
|
+
uint16_t* bq = biases.get() + i * cur_nprobe;
|
|
600
|
+
for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
|
|
601
|
+
for (size_t m = 0; m < M; m++) {
|
|
602
|
+
const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
|
|
603
|
+
float mn = all_mins[j2 * M + m];
|
|
604
|
+
uint8_t* out = out_base + j2 * dim12_2 + m * ksub;
|
|
605
|
+
for (size_t s = 0; s < ksub; s++) {
|
|
606
|
+
out[s] = static_cast<uint8_t>(
|
|
607
|
+
std::roundf(a * (tab[s] - mn)));
|
|
608
|
+
}
|
|
609
|
+
}
|
|
610
|
+
memset(out_base + j2 * dim12_2 + M * ksub, 0, (M2 - M) * ksub);
|
|
611
|
+
bq[j2] = static_cast<uint16_t>(
|
|
612
|
+
std::roundf(a * (probe_b[j2] - glob_b)));
|
|
613
|
+
}
|
|
614
|
+
normalizers[2 * i] = a;
|
|
615
|
+
normalizers[2 * i + 1] = glob_b;
|
|
413
616
|
}
|
|
414
617
|
}
|
|
415
618
|
}
|
|
@@ -441,23 +644,22 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
|
|
|
441
644
|
}
|
|
442
645
|
}
|
|
443
646
|
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
}
|
|
647
|
+
const size_t storage_size = compute_per_vector_storage_size();
|
|
648
|
+
const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
|
|
649
|
+
const size_t full_block_size = get_block_stride();
|
|
650
|
+
|
|
651
|
+
InvertedLists::ScopedCodes list_block_codes(invlists, list_no);
|
|
652
|
+
const uint8_t* aux_ptr = rabitq_utils::get_block_aux_ptr(
|
|
653
|
+
list_block_codes.get(),
|
|
654
|
+
offset,
|
|
655
|
+
bbs,
|
|
656
|
+
packed_block_size,
|
|
657
|
+
full_block_size,
|
|
658
|
+
storage_size);
|
|
659
|
+
|
|
660
|
+
const auto& base_factors =
|
|
661
|
+
*reinterpret_cast<const SignBitFactors*>(aux_ptr);
|
|
662
|
+
const float dp_multiplier = base_factors.dp_multiplier;
|
|
461
663
|
|
|
462
664
|
// Decode residual directly using dp_multiplier
|
|
463
665
|
std::vector<float> residual(d);
|
|
@@ -465,7 +667,7 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
|
|
|
465
667
|
fastscan_code.data(), residual.data(), dp_multiplier);
|
|
466
668
|
|
|
467
669
|
// Reconstruct: x = centroid + residual
|
|
468
|
-
for (size_t j = 0; j < d; j++) {
|
|
670
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
469
671
|
recons[j] = centroid[j] + residual[j];
|
|
470
672
|
}
|
|
471
673
|
}
|
|
@@ -490,7 +692,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
490
692
|
|
|
491
693
|
idx_t list_no = decode_listno(code_i);
|
|
492
694
|
|
|
493
|
-
if (list_no >= 0 && list_no < nlist) {
|
|
695
|
+
if (list_no >= 0 && list_no < static_cast<idx_t>(nlist)) {
|
|
494
696
|
quantizer->reconstruct(list_no, centroid.data());
|
|
495
697
|
|
|
496
698
|
const uint8_t* fastscan_code = code_i + coarse_size;
|
|
@@ -502,7 +704,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
502
704
|
decode_fastscan_to_residual(
|
|
503
705
|
fastscan_code, residual.data(), base_factors.dp_multiplier);
|
|
504
706
|
|
|
505
|
-
for (size_t j = 0; j < d; j++) {
|
|
707
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
506
708
|
x_i[j] = centroid[j] + residual[j];
|
|
507
709
|
}
|
|
508
710
|
} else {
|
|
@@ -519,7 +721,7 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
|
|
|
519
721
|
|
|
520
722
|
const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
521
723
|
|
|
522
|
-
for (size_t j = 0; j < d; j++) {
|
|
724
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
523
725
|
bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);
|
|
524
726
|
|
|
525
727
|
float bit_as_float = bit_value ? 1.0f : 0.0f;
|
|
@@ -527,302 +729,248 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
|
|
|
527
729
|
}
|
|
528
730
|
}
|
|
529
731
|
|
|
530
|
-
|
|
531
|
-
SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
|
|
732
|
+
std::unique_ptr<FastScanCodeScanner> IndexIVFRaBitQFastScan::make_knn_scanner(
|
|
532
733
|
bool is_max,
|
|
533
|
-
int /* impl */,
|
|
534
734
|
idx_t n,
|
|
535
735
|
idx_t k,
|
|
536
736
|
float* distances,
|
|
537
737
|
idx_t* labels,
|
|
538
|
-
const IDSelector*
|
|
539
|
-
|
|
540
|
-
const
|
|
541
|
-
const
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
if (is_max) {
|
|
545
|
-
return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
|
|
546
|
-
this, n, k, distances, labels, &context, is_multibit);
|
|
547
|
-
} else {
|
|
548
|
-
return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
|
|
549
|
-
this, n, k, distances, labels, &context, is_multibit);
|
|
550
|
-
}
|
|
738
|
+
const IDSelector* sel,
|
|
739
|
+
int /*impl*/,
|
|
740
|
+
const FastScanDistancePostProcessing& context) const {
|
|
741
|
+
const bool is_multibit = (rabitq.nb_bits - 1) > 0;
|
|
742
|
+
return rabitq_ivf_make_knn_scanner(
|
|
743
|
+
is_max, this, n, k, distances, labels, sel, &context, is_multibit);
|
|
551
744
|
}
|
|
552
745
|
|
|
553
746
|
/*********************************************************
|
|
554
|
-
*
|
|
747
|
+
* IVFRaBitQFastScanScanner implementation
|
|
555
748
|
*********************************************************/
|
|
556
749
|
|
|
557
|
-
|
|
558
|
-
IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
|
|
559
|
-
const IndexIVFRaBitQFastScan* idx,
|
|
560
|
-
size_t nq_val,
|
|
561
|
-
size_t k_val,
|
|
562
|
-
float* distances,
|
|
563
|
-
int64_t* labels,
|
|
564
|
-
const FastScanDistancePostProcessing* ctx,
|
|
565
|
-
bool multibit)
|
|
566
|
-
: simd_result_handlers::ResultHandlerCompare<C, true>(
|
|
567
|
-
nq_val,
|
|
568
|
-
0,
|
|
569
|
-
nullptr),
|
|
570
|
-
index(idx),
|
|
571
|
-
heap_distances(distances),
|
|
572
|
-
heap_labels(labels),
|
|
573
|
-
nq(nq_val),
|
|
574
|
-
k(k_val),
|
|
575
|
-
context(ctx),
|
|
576
|
-
is_multibit(multibit) {
|
|
577
|
-
current_list_no = 0;
|
|
578
|
-
probe_indices.clear();
|
|
579
|
-
|
|
580
|
-
// Initialize heaps in constructor (standard pattern from HeapHandler)
|
|
581
|
-
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
582
|
-
float* heap_dis = heap_distances + q * k;
|
|
583
|
-
int64_t* heap_ids = heap_labels + q * k;
|
|
584
|
-
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
|
|
585
|
-
}
|
|
586
|
-
}
|
|
587
|
-
|
|
588
|
-
template <class C>
|
|
589
|
-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
|
|
590
|
-
size_t q,
|
|
591
|
-
size_t b,
|
|
592
|
-
simd16uint16 d0,
|
|
593
|
-
simd16uint16 d1) {
|
|
594
|
-
// Store the original local query index before adjust_with_origin changes it
|
|
595
|
-
size_t local_q = q;
|
|
596
|
-
this->adjust_with_origin(q, d0, d1);
|
|
597
|
-
|
|
598
|
-
ALIGNED(32) uint16_t d32tab[32];
|
|
599
|
-
d0.store(d32tab);
|
|
600
|
-
d1.store(d32tab + 16);
|
|
601
|
-
|
|
602
|
-
float* const heap_dis = heap_distances + q * k;
|
|
603
|
-
int64_t* const heap_ids = heap_labels + q * k;
|
|
604
|
-
|
|
605
|
-
FAISS_THROW_IF_NOT_FMT(
|
|
606
|
-
!probe_indices.empty() && local_q < probe_indices.size(),
|
|
607
|
-
"set_list_context() must be called before handle() - probe_indices size: %zu, local_q: %zu, global_q: %zu",
|
|
608
|
-
probe_indices.size(),
|
|
609
|
-
local_q,
|
|
610
|
-
q);
|
|
611
|
-
|
|
612
|
-
// Access query factors directly from array via ProcessingContext
|
|
613
|
-
if (!context || !context->query_factors) {
|
|
614
|
-
FAISS_THROW_MSG(
|
|
615
|
-
"Query factors not available: FastScanDistancePostProcessing with query_factors required");
|
|
616
|
-
}
|
|
750
|
+
namespace {
|
|
617
751
|
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
752
|
+
/// Provides IVF scanner interface using FastScan's SIMD batch processing.
|
|
753
|
+
/// Buffers are allocated once and reused across set_list + scan_codes calls.
|
|
754
|
+
struct IVFRaBitQFastScanScanner : InvertedListScanner {
|
|
755
|
+
using InvertedListScanner::scan_codes;
|
|
756
|
+
static constexpr size_t nq = 1;
|
|
622
757
|
|
|
623
|
-
const
|
|
758
|
+
const IndexIVFRaBitQFastScan& index;
|
|
759
|
+
const uint8_t qb;
|
|
760
|
+
const bool centered;
|
|
624
761
|
|
|
625
|
-
const float
|
|
626
|
-
this->normalizers ? (1.0f / this->normalizers[2 * q]) : 1.0f;
|
|
627
|
-
const float bias = this->normalizers ? this->normalizers[2 * q + 1] : 0.0f;
|
|
762
|
+
const float* xi = nullptr;
|
|
628
763
|
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
764
|
+
// Reusable buffers (allocated once in constructor)
|
|
765
|
+
AlignedTable<uint8_t> dis_tables;
|
|
766
|
+
AlignedTable<uint16_t> biases;
|
|
767
|
+
std::array<float, 2> normalizers{};
|
|
768
|
+
AlignedTable<float> lut_float;
|
|
769
|
+
std::vector<float> rotated_q;
|
|
770
|
+
std::vector<float> centroid_buf;
|
|
771
|
+
QueryFactorsData query_factors;
|
|
772
|
+
FastScanDistancePostProcessing context;
|
|
773
|
+
std::vector<int> probe_map;
|
|
774
|
+
std::vector<float> mins_buf;
|
|
775
|
+
|
|
776
|
+
// Distance computer for distance_to_code (created in set_list)
|
|
777
|
+
std::unique_ptr<FlatCodesDistanceComputer> dc;
|
|
778
|
+
|
|
779
|
+
IVFRaBitQFastScanScanner(
|
|
780
|
+
const IndexIVFRaBitQFastScan& index_in,
|
|
781
|
+
bool store_pairs_in,
|
|
782
|
+
const IDSelector* sel_in,
|
|
783
|
+
uint8_t qb_in,
|
|
784
|
+
bool centered_in)
|
|
785
|
+
: InvertedListScanner(store_pairs_in, sel_in),
|
|
786
|
+
index(index_in),
|
|
787
|
+
qb(qb_in),
|
|
788
|
+
centered(centered_in),
|
|
789
|
+
lut_float(16 * index_in.M),
|
|
790
|
+
rotated_q(index_in.d),
|
|
791
|
+
centroid_buf(index_in.d),
|
|
792
|
+
probe_map({0}),
|
|
793
|
+
mins_buf(index_in.M) {
|
|
794
|
+
this->keep_max = is_similarity_metric(index_in.metric_type);
|
|
795
|
+
this->code_size = index_in.code_size;
|
|
796
|
+
|
|
797
|
+
// Pre-allocate output tables for single probe
|
|
798
|
+
dis_tables.resize(16 * index_in.M2);
|
|
799
|
+
biases.resize(1);
|
|
800
|
+
|
|
801
|
+
// Set up context once
|
|
802
|
+
context.query_factors = &query_factors;
|
|
803
|
+
context.nprobe = 1;
|
|
804
|
+
context.qb = qb;
|
|
805
|
+
context.centered = centered;
|
|
632
806
|
}
|
|
633
807
|
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
638
|
-
// n_multibit_evaluations: candidates requiring full multi-bit distance
|
|
639
|
-
size_t local_1bit_evaluations = 0;
|
|
640
|
-
size_t local_multibit_evaluations = 0;
|
|
808
|
+
void set_query(const float* query) override {
|
|
809
|
+
this->xi = query;
|
|
810
|
+
}
|
|
641
811
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
812
|
+
void set_list(idx_t list_no_in, float /*coarse_dis_in*/) override {
|
|
813
|
+
this->list_no = list_no_in;
|
|
814
|
+
|
|
815
|
+
index.compute_residual_LUT(
|
|
816
|
+
xi,
|
|
817
|
+
list_no_in,
|
|
818
|
+
query_factors,
|
|
819
|
+
lut_float.get(),
|
|
820
|
+
qb,
|
|
821
|
+
centered,
|
|
822
|
+
rotated_q,
|
|
823
|
+
centroid_buf);
|
|
824
|
+
|
|
825
|
+
// Single-probe quantization (simplified inline, no OMP, no 3D)
|
|
826
|
+
const size_t M = index.M;
|
|
827
|
+
const size_t M2 = index.M2;
|
|
828
|
+
const size_t ksub = index.ksub;
|
|
829
|
+
|
|
830
|
+
float max_span = -HUGE_VAL;
|
|
831
|
+
float max_dis = 0;
|
|
832
|
+
float b = 0;
|
|
833
|
+
float* mins = mins_buf.data();
|
|
645
834
|
|
|
646
|
-
|
|
647
|
-
|
|
835
|
+
for (size_t m = 0; m < M; m++) {
|
|
836
|
+
const float* tab = lut_float.get() + m * ksub;
|
|
837
|
+
float mn = tab[0], mx = tab[0];
|
|
838
|
+
for (size_t s = 1; s < ksub; s++) {
|
|
839
|
+
mn = std::min(mn, tab[s]);
|
|
840
|
+
mx = std::max(mx, tab[s]);
|
|
841
|
+
}
|
|
842
|
+
mins[m] = mn;
|
|
843
|
+
float span = mx - mn;
|
|
844
|
+
max_span = std::max(max_span, span);
|
|
845
|
+
max_dis += span;
|
|
846
|
+
b += mn;
|
|
648
847
|
}
|
|
649
848
|
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
if (is_multibit) {
|
|
658
|
-
// Track candidates actually considered for two-stage filtering
|
|
659
|
-
local_1bit_evaluations++;
|
|
660
|
-
|
|
661
|
-
// Multi-bit: use SignBitFactorsWithError and two-stage search
|
|
662
|
-
const SignBitFactorsWithError& full_factors =
|
|
663
|
-
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
664
|
-
|
|
665
|
-
// Compute 1-bit adjusted distance using shared helper
|
|
666
|
-
float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
|
|
667
|
-
normalized_distance,
|
|
668
|
-
full_factors,
|
|
669
|
-
query_factors,
|
|
670
|
-
index->centered,
|
|
671
|
-
index->qb,
|
|
672
|
-
index->d);
|
|
673
|
-
|
|
674
|
-
// Compute lower bound using error bound
|
|
675
|
-
float lower_bound =
|
|
676
|
-
compute_lower_bound(dist_1bit, result_id, local_q, q);
|
|
677
|
-
|
|
678
|
-
// Adaptive filtering: decide whether to compute full distance
|
|
679
|
-
const bool is_similarity =
|
|
680
|
-
index->metric_type == MetricType::METRIC_INNER_PRODUCT;
|
|
681
|
-
bool should_refine = is_similarity
|
|
682
|
-
? (lower_bound > heap_dis[0]) // IP: keep if better
|
|
683
|
-
: (lower_bound < heap_dis[0]); // L2: keep if better
|
|
684
|
-
|
|
685
|
-
if (should_refine) {
|
|
686
|
-
local_multibit_evaluations++;
|
|
687
|
-
|
|
688
|
-
// Compute local_offset: position within current inverted list
|
|
689
|
-
size_t local_offset = this->j0 + b * 32 + j;
|
|
690
|
-
|
|
691
|
-
// Compute full multi-bit distance
|
|
692
|
-
float dist_full = compute_full_multibit_distance(
|
|
693
|
-
result_id, local_q, q, local_offset);
|
|
694
|
-
|
|
695
|
-
// Update heap if this distance is better
|
|
696
|
-
if (Cfloat::cmp(heap_dis[0], dist_full)) {
|
|
697
|
-
heap_replace_top<Cfloat>(
|
|
698
|
-
k, heap_dis, heap_ids, dist_full, result_id);
|
|
699
|
-
}
|
|
700
|
-
}
|
|
701
|
-
} else {
|
|
702
|
-
const auto& db_factors =
|
|
703
|
-
*reinterpret_cast<const SignBitFactors*>(base_ptr);
|
|
704
|
-
|
|
705
|
-
// Compute adjusted distance using shared helper
|
|
706
|
-
float adjusted_distance =
|
|
707
|
-
rabitq_utils::compute_1bit_adjusted_distance(
|
|
708
|
-
normalized_distance,
|
|
709
|
-
db_factors,
|
|
710
|
-
query_factors,
|
|
711
|
-
index->centered,
|
|
712
|
-
index->qb,
|
|
713
|
-
index->d);
|
|
714
|
-
|
|
715
|
-
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
716
|
-
heap_replace_top<Cfloat>(
|
|
717
|
-
k, heap_dis, heap_ids, adjusted_distance, result_id);
|
|
849
|
+
float a = std::min(255.0f / max_span, 65535.0f / max_dis);
|
|
850
|
+
uint8_t* out = dis_tables.get();
|
|
851
|
+
for (size_t m = 0; m < M; m++) {
|
|
852
|
+
const float* tab = lut_float.get() + m * ksub;
|
|
853
|
+
for (size_t s = 0; s < ksub; s++) {
|
|
854
|
+
out[m * ksub + s] = static_cast<uint8_t>(
|
|
855
|
+
std::roundf(a * (tab[s] - mins[m])));
|
|
718
856
|
}
|
|
719
857
|
}
|
|
858
|
+
memset(out + M * ksub, 0, (M2 - M) * ksub);
|
|
859
|
+
biases[0] = 0;
|
|
860
|
+
normalizers[0] = a;
|
|
861
|
+
normalizers[1] = b;
|
|
862
|
+
|
|
863
|
+
// Create distance computer (reuses centroid_buf from
|
|
864
|
+
// compute_residual_LUT)
|
|
865
|
+
dc.reset(index.rabitq.get_distance_computer(
|
|
866
|
+
qb, centroid_buf.data(), centered));
|
|
867
|
+
dc->set_query(xi);
|
|
720
868
|
}
|
|
721
869
|
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
#pragma omp atomic
|
|
726
|
-
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
|
|
727
|
-
}
|
|
870
|
+
float distance_to_code(const uint8_t* code) const override {
|
|
871
|
+
return dc->distance_to_code(code);
|
|
872
|
+
}
|
|
728
873
|
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
874
|
+
size_t scan_codes(
|
|
875
|
+
size_t ntotal,
|
|
876
|
+
const uint8_t* codes,
|
|
877
|
+
const idx_t* ids,
|
|
878
|
+
ResultHandler& result_handler) const override {
|
|
879
|
+
auto scan_with_heap = [&](auto* heap_handler) -> size_t {
|
|
880
|
+
const size_t k = heap_handler->k;
|
|
881
|
+
if (k == 0) {
|
|
882
|
+
return 0;
|
|
883
|
+
}
|
|
736
884
|
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
885
|
+
std::vector<float> curr_dists(k, result_handler.threshold);
|
|
886
|
+
std::vector<idx_t> curr_labels(k, -1);
|
|
887
|
+
|
|
888
|
+
auto scanner = index.make_knn_scanner(
|
|
889
|
+
!keep_max,
|
|
890
|
+
nq,
|
|
891
|
+
k,
|
|
892
|
+
curr_dists.data(),
|
|
893
|
+
curr_labels.data(),
|
|
894
|
+
sel,
|
|
895
|
+
0,
|
|
896
|
+
context);
|
|
897
|
+
auto* handler = scanner->handler();
|
|
898
|
+
|
|
899
|
+
int qmap1[1] = {0};
|
|
900
|
+
handler->q_map = qmap1;
|
|
901
|
+
handler->begin(&normalizers[0]);
|
|
902
|
+
handler->dbias = biases.get();
|
|
903
|
+
handler->ntotal = ntotal;
|
|
904
|
+
handler->id_map = ids;
|
|
905
|
+
|
|
906
|
+
handler->set_list_context(list_no, probe_map);
|
|
907
|
+
if (!handler->list_codes_ptr) {
|
|
908
|
+
handler->list_codes_ptr = codes;
|
|
909
|
+
}
|
|
742
910
|
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
911
|
+
scanner->accumulate_loop(
|
|
912
|
+
1,
|
|
913
|
+
roundup(ntotal, index.bbs),
|
|
914
|
+
index.bbs,
|
|
915
|
+
static_cast<int>(index.M2),
|
|
916
|
+
codes,
|
|
917
|
+
dis_tables.get(),
|
|
918
|
+
0,
|
|
919
|
+
index.get_block_stride());
|
|
920
|
+
|
|
921
|
+
const size_t scan_cnt = handler->count_scanned_rows();
|
|
922
|
+
handler->end();
|
|
923
|
+
|
|
924
|
+
result_handler.stats.scan_cnt += scan_cnt;
|
|
925
|
+
size_t nup = 0;
|
|
926
|
+
for (size_t j = 0; j < k; j++) {
|
|
927
|
+
if (curr_labels[j] < 0) {
|
|
928
|
+
continue;
|
|
929
|
+
}
|
|
930
|
+
if (result_handler.add_result(curr_dists[j], curr_labels[j])) {
|
|
931
|
+
result_handler.stats.nheap_updates++;
|
|
932
|
+
nup++;
|
|
933
|
+
}
|
|
934
|
+
}
|
|
935
|
+
return nup;
|
|
936
|
+
};
|
|
937
|
+
|
|
938
|
+
if (!keep_max) {
|
|
939
|
+
using C = CMax<float, idx_t>;
|
|
940
|
+
if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
|
|
941
|
+
&result_handler)) {
|
|
942
|
+
return scan_with_heap(heap_handler);
|
|
943
|
+
}
|
|
944
|
+
} else {
|
|
945
|
+
using C = CMin<float, idx_t>;
|
|
946
|
+
if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
|
|
947
|
+
&result_handler)) {
|
|
948
|
+
return scan_with_heap(heap_handler);
|
|
949
|
+
}
|
|
950
|
+
}
|
|
752
951
|
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
size_t local_q,
|
|
758
|
-
size_t global_q) const {
|
|
759
|
-
// Access f_error from SignBitFactorsWithError in flat storage
|
|
760
|
-
const size_t storage_size = index->compute_per_vector_storage_size();
|
|
761
|
-
const uint8_t* base_ptr =
|
|
762
|
-
index->flat_storage.data() + db_idx * storage_size;
|
|
763
|
-
const SignBitFactorsWithError& db_factors =
|
|
764
|
-
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
765
|
-
float f_error = db_factors.f_error;
|
|
766
|
-
|
|
767
|
-
// Get g_error from query factors
|
|
768
|
-
// Use local_q to access probe_indices (batch-local), global_q for storage
|
|
769
|
-
float g_error = 0.0f;
|
|
770
|
-
if (context && context->query_factors) {
|
|
771
|
-
size_t probe_rank = probe_indices[local_q];
|
|
772
|
-
size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
|
|
773
|
-
size_t storage_idx = global_q * nprobe + probe_rank;
|
|
774
|
-
g_error = context->query_factors[storage_idx].g_error;
|
|
952
|
+
FAISS_THROW_MSG(
|
|
953
|
+
"IVFRaBitQFastScanScanner::scan_codes requires "
|
|
954
|
+
"HeapResultHandler; custom ResultHandler scan is not supported "
|
|
955
|
+
"by this optimized scanner");
|
|
775
956
|
}
|
|
957
|
+
};
|
|
776
958
|
|
|
777
|
-
|
|
778
|
-
float error_adjustment = f_error * g_error;
|
|
959
|
+
} // anonymous namespace
|
|
779
960
|
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
const size_t storage_size = index->compute_per_vector_storage_size();
|
|
794
|
-
const uint8_t* base_ptr =
|
|
795
|
-
index->flat_storage.data() + db_idx * storage_size;
|
|
796
|
-
|
|
797
|
-
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
798
|
-
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
799
|
-
const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
|
|
800
|
-
base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
|
|
801
|
-
|
|
802
|
-
// Use local_q to access probe_indices (batch-local), global_q for storage
|
|
803
|
-
size_t probe_rank = probe_indices[local_q];
|
|
804
|
-
size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
|
|
805
|
-
size_t storage_idx = global_q * nprobe + probe_rank;
|
|
806
|
-
const auto& query_factors = context->query_factors[storage_idx];
|
|
807
|
-
|
|
808
|
-
size_t list_no = current_list_no;
|
|
809
|
-
InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
|
|
810
|
-
|
|
811
|
-
std::vector<uint8_t> unpacked_code(index->code_size);
|
|
812
|
-
CodePackerPQ4 packer(index->M2, index->bbs);
|
|
813
|
-
packer.unpack_1(list_codes.get(), local_offset, unpacked_code.data());
|
|
814
|
-
const uint8_t* sign_bits = unpacked_code.data();
|
|
815
|
-
|
|
816
|
-
return rabitq_utils::compute_full_multibit_distance(
|
|
817
|
-
sign_bits,
|
|
818
|
-
ex_code,
|
|
819
|
-
ex_fac,
|
|
820
|
-
query_factors.rotated_q.data(),
|
|
821
|
-
query_factors.qr_to_c_L2sqr,
|
|
822
|
-
query_factors.qr_norm_L2sqr,
|
|
823
|
-
dim,
|
|
824
|
-
ex_bits,
|
|
825
|
-
index->metric_type);
|
|
961
|
+
InvertedListScanner* IndexIVFRaBitQFastScan::get_InvertedListScanner(
|
|
962
|
+
bool store_pairs,
|
|
963
|
+
const IDSelector* sel,
|
|
964
|
+
const IVFSearchParameters* search_params_in) const {
|
|
965
|
+
uint8_t used_qb = qb;
|
|
966
|
+
bool used_centered = centered;
|
|
967
|
+
if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>(
|
|
968
|
+
search_params_in)) {
|
|
969
|
+
used_qb = params->qb;
|
|
970
|
+
used_centered = params->centered;
|
|
971
|
+
}
|
|
972
|
+
return new IVFRaBitQFastScanScanner(
|
|
973
|
+
*this, store_pairs, sel, used_qb, used_centered);
|
|
826
974
|
}
|
|
827
975
|
|
|
828
976
|
} // namespace faiss
|