faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -14,11 +14,11 @@
|
|
|
14
14
|
|
|
15
15
|
#include <faiss/impl/CodePackerRaBitQ.h>
|
|
16
16
|
#include <faiss/impl/FaissAssert.h>
|
|
17
|
-
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
18
17
|
#include <faiss/impl/RaBitQUtils.h>
|
|
19
18
|
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
20
|
-
#include <faiss/impl/
|
|
21
|
-
#include <faiss/impl/
|
|
19
|
+
#include <faiss/impl/ResultHandler.h>
|
|
20
|
+
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
|
|
21
|
+
#include <faiss/impl/fast_scan/fast_scan.h>
|
|
22
22
|
#include <faiss/invlists/BlockInvertedLists.h>
|
|
23
23
|
#include <faiss/utils/distances.h>
|
|
24
24
|
#include <faiss/utils/utils.h>
|
|
@@ -42,31 +42,38 @@ inline size_t roundup(size_t a, size_t b) {
|
|
|
42
42
|
IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan() = default;
|
|
43
43
|
|
|
44
44
|
IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
|
|
45
|
-
Index*
|
|
46
|
-
size_t
|
|
47
|
-
size_t
|
|
45
|
+
Index* quantizer_in,
|
|
46
|
+
size_t d_in,
|
|
47
|
+
size_t nlist_in,
|
|
48
48
|
MetricType metric,
|
|
49
|
-
int
|
|
50
|
-
bool
|
|
49
|
+
int bbs_in,
|
|
50
|
+
bool own_invlists_in,
|
|
51
51
|
uint8_t nb_bits)
|
|
52
|
-
: IndexIVFFastScan(
|
|
53
|
-
|
|
54
|
-
|
|
52
|
+
: IndexIVFFastScan(
|
|
53
|
+
quantizer_in,
|
|
54
|
+
d_in,
|
|
55
|
+
nlist_in,
|
|
56
|
+
0,
|
|
57
|
+
metric,
|
|
58
|
+
own_invlists_in),
|
|
59
|
+
rabitq(d_in, metric, nb_bits) {
|
|
60
|
+
FAISS_THROW_IF_NOT_MSG(d_in > 0, "Dimension must be positive");
|
|
55
61
|
FAISS_THROW_IF_NOT_MSG(
|
|
56
62
|
metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
|
|
57
63
|
"RaBitQ only supports L2 and Inner Product metrics");
|
|
58
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
59
|
-
|
|
64
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
65
|
+
bbs_in % 32 == 0, "Batch size must be multiple of 32");
|
|
66
|
+
FAISS_THROW_IF_NOT_MSG(quantizer_in != nullptr, "Quantizer cannot be null");
|
|
60
67
|
|
|
61
68
|
by_residual = true;
|
|
62
69
|
qb = 8; // RaBitQ quantization bits
|
|
63
70
|
centered = false;
|
|
64
71
|
|
|
65
72
|
// FastScan-specific parameters: 4 bits per sub-quantizer
|
|
66
|
-
const size_t M_fastscan = (
|
|
73
|
+
const size_t M_fastscan = (d_in + 3) / 4;
|
|
67
74
|
constexpr size_t nbits_fastscan = 4;
|
|
68
75
|
|
|
69
|
-
this->bbs =
|
|
76
|
+
this->bbs = bbs_in;
|
|
70
77
|
this->fine_quantizer = &rabitq;
|
|
71
78
|
this->M = M_fastscan;
|
|
72
79
|
this->nbits = nbits_fastscan;
|
|
@@ -101,6 +108,10 @@ size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
|
|
|
101
108
|
return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
|
|
102
109
|
}
|
|
103
110
|
|
|
111
|
+
size_t IndexIVFRaBitQFastScan::fast_scan_code_size() const {
|
|
112
|
+
return (d + 7) / 8;
|
|
113
|
+
}
|
|
114
|
+
|
|
104
115
|
size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
|
|
105
116
|
// Use code_size as stride to skip embedded factor data during packing
|
|
106
117
|
return code_size;
|
|
@@ -195,7 +206,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
|
|
|
195
206
|
const size_t bit_pattern_size = (d + 7) / 8;
|
|
196
207
|
|
|
197
208
|
// Pack sign bits directly into FastScan format (inline)
|
|
198
|
-
for (size_t j = 0; j < d; j++) {
|
|
209
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
199
210
|
const float or_minus_c = xi[j] - centroid[j];
|
|
200
211
|
if (or_minus_c > 0.0f) {
|
|
201
212
|
rabitq_utils::set_bit_fastscan(fastscan_code, j);
|
|
@@ -224,7 +235,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
|
|
|
224
235
|
|
|
225
236
|
// Compute residual (needed for quantize_ex_bits)
|
|
226
237
|
std::vector<float> residual(d);
|
|
227
|
-
for (size_t j = 0; j < d; j++) {
|
|
238
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
228
239
|
residual[j] = xi[j] - centroid[j];
|
|
229
240
|
}
|
|
230
241
|
|
|
@@ -261,84 +272,133 @@ bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
|
|
|
261
272
|
return true;
|
|
262
273
|
}
|
|
263
274
|
|
|
275
|
+
// out[code] = base + sum of v_i for each set bit in code.
|
|
276
|
+
inline void write_subset_sum_lut(
|
|
277
|
+
float* out,
|
|
278
|
+
float base,
|
|
279
|
+
float v0,
|
|
280
|
+
float v1,
|
|
281
|
+
float v2,
|
|
282
|
+
float v3) {
|
|
283
|
+
out[0] = base;
|
|
284
|
+
out[1] = base + v0;
|
|
285
|
+
out[2] = base + v1;
|
|
286
|
+
out[3] = base + v0 + v1;
|
|
287
|
+
out[4] = base + v2;
|
|
288
|
+
out[5] = base + v0 + v2;
|
|
289
|
+
out[6] = base + v1 + v2;
|
|
290
|
+
out[7] = base + v0 + v1 + v2;
|
|
291
|
+
out[8] = base + v3;
|
|
292
|
+
out[9] = base + v0 + v3;
|
|
293
|
+
out[10] = base + v1 + v3;
|
|
294
|
+
out[11] = base + v0 + v1 + v3;
|
|
295
|
+
out[12] = base + v2 + v3;
|
|
296
|
+
out[13] = base + v0 + v2 + v3;
|
|
297
|
+
out[14] = base + v1 + v2 + v3;
|
|
298
|
+
out[15] = base + v0 + v1 + v2 + v3;
|
|
299
|
+
}
|
|
300
|
+
|
|
264
301
|
// Computes lookup table for residual vectors in RaBitQ FastScan format
|
|
265
302
|
void IndexIVFRaBitQFastScan::compute_residual_LUT(
|
|
266
|
-
const float*
|
|
303
|
+
const float* query,
|
|
304
|
+
idx_t centroid_id,
|
|
267
305
|
QueryFactorsData& query_factors,
|
|
268
306
|
float* lut_out,
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
307
|
+
uint8_t qb_param,
|
|
308
|
+
bool centered_param,
|
|
309
|
+
std::vector<float>& rotated_q,
|
|
310
|
+
std::vector<float>& centroid_buf) const {
|
|
311
|
+
const size_t d_val = static_cast<size_t>(d);
|
|
312
|
+
FAISS_THROW_IF_NOT(d_val > 0);
|
|
313
|
+
rotated_q.resize(d_val);
|
|
314
|
+
centroid_buf.resize(d_val);
|
|
315
|
+
std::vector<uint8_t> rotated_qq(d_val);
|
|
316
|
+
|
|
317
|
+
// Compute residual
|
|
318
|
+
quantizer->reconstruct(centroid_id, centroid_buf.data());
|
|
319
|
+
for (size_t i = 0; i < d_val; i++) {
|
|
320
|
+
rotated_q[i] = query[i] - centroid_buf[i];
|
|
321
|
+
}
|
|
274
322
|
|
|
275
|
-
//
|
|
323
|
+
// Compute query factors using shared utility
|
|
276
324
|
query_factors = rabitq_utils::compute_query_factors(
|
|
277
|
-
|
|
278
|
-
|
|
325
|
+
rotated_q.data(),
|
|
326
|
+
d_val,
|
|
279
327
|
nullptr,
|
|
280
|
-
|
|
281
|
-
|
|
328
|
+
qb_param,
|
|
329
|
+
centered_param,
|
|
282
330
|
metric_type,
|
|
283
331
|
rotated_q,
|
|
284
332
|
rotated_qq);
|
|
285
333
|
|
|
286
|
-
if (metric_type == MetricType::METRIC_INNER_PRODUCT
|
|
287
|
-
|
|
288
|
-
query_factors.
|
|
289
|
-
|
|
290
|
-
fvec_inner_product(original_query, residual, d);
|
|
334
|
+
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
|
|
335
|
+
query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d_val);
|
|
336
|
+
query_factors.q_dot_c =
|
|
337
|
+
fvec_inner_product(query, centroid_buf.data(), d_val);
|
|
291
338
|
}
|
|
292
339
|
|
|
293
|
-
|
|
294
|
-
if (ex_bits > 0) {
|
|
340
|
+
if (rabitq.nb_bits > 1) {
|
|
295
341
|
query_factors.rotated_q = rotated_q;
|
|
296
342
|
}
|
|
297
343
|
|
|
298
|
-
|
|
299
|
-
|
|
344
|
+
// Build LUT using branchless subset-sum construction
|
|
345
|
+
const size_t d_sz = d_val;
|
|
300
346
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
for (int code_val = 0; code_val < 16; code_val++) {
|
|
305
|
-
float xor_contribution = 0.0f;
|
|
306
|
-
|
|
307
|
-
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
308
|
-
const size_t dim_idx = dim_start + dim_offset;
|
|
347
|
+
if (centered_param) {
|
|
348
|
+
const float mcv = static_cast<float>((1 << qb_param) - 1);
|
|
309
349
|
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
350
|
+
for (size_t m = 0; m < M; m++) {
|
|
351
|
+
const size_t ds = m * 4;
|
|
352
|
+
float* out = lut_out + m * 16;
|
|
353
|
+
|
|
354
|
+
float base = 0.0f;
|
|
355
|
+
float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
|
|
356
|
+
if (ds + 0 < d_sz) {
|
|
357
|
+
float q = rotated_qq[ds + 0];
|
|
358
|
+
base += q;
|
|
359
|
+
v0 = mcv - 2.0f * q;
|
|
360
|
+
}
|
|
361
|
+
if (ds + 1 < d_sz) {
|
|
362
|
+
float q = rotated_qq[ds + 1];
|
|
363
|
+
base += q;
|
|
364
|
+
v1 = mcv - 2.0f * q;
|
|
321
365
|
}
|
|
366
|
+
if (ds + 2 < d_sz) {
|
|
367
|
+
float q = rotated_qq[ds + 2];
|
|
368
|
+
base += q;
|
|
369
|
+
v2 = mcv - 2.0f * q;
|
|
370
|
+
}
|
|
371
|
+
if (ds + 3 < d_sz) {
|
|
372
|
+
float q = rotated_qq[ds + 3];
|
|
373
|
+
base += q;
|
|
374
|
+
v3 = mcv - 2.0f * q;
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
write_subset_sum_lut(out, base, v0, v1, v2, v3);
|
|
322
378
|
}
|
|
323
379
|
} else {
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
for (int code_val = 0; code_val < 16; code_val++) {
|
|
328
|
-
float inner_product = 0.0f;
|
|
329
|
-
int popcount = 0;
|
|
380
|
+
const float c1 = query_factors.c1;
|
|
381
|
+
const float c2 = query_factors.c2;
|
|
330
382
|
|
|
331
|
-
|
|
332
|
-
|
|
383
|
+
for (size_t m = 0; m < M; m++) {
|
|
384
|
+
const size_t ds = m * 4;
|
|
385
|
+
float* out = lut_out + m * 16;
|
|
333
386
|
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
query_factors.c2 * popcount;
|
|
387
|
+
float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
|
|
388
|
+
if (ds + 0 < d_sz) {
|
|
389
|
+
v0 = c1 * rotated_qq[ds + 0] + c2;
|
|
390
|
+
}
|
|
391
|
+
if (ds + 1 < d_sz) {
|
|
392
|
+
v1 = c1 * rotated_qq[ds + 1] + c2;
|
|
341
393
|
}
|
|
394
|
+
if (ds + 2 < d_sz) {
|
|
395
|
+
v2 = c1 * rotated_qq[ds + 2] + c2;
|
|
396
|
+
}
|
|
397
|
+
if (ds + 3 < d_sz) {
|
|
398
|
+
v3 = c1 * rotated_qq[ds + 3] + c2;
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
write_subset_sum_lut(out, 0.0f, v0, v1, v2, v3);
|
|
342
402
|
}
|
|
343
403
|
}
|
|
344
404
|
}
|
|
@@ -360,18 +420,27 @@ void IndexIVFRaBitQFastScan::search_preassigned(
|
|
|
360
420
|
!store_pairs, "store_pairs not supported for RaBitQFastScan");
|
|
361
421
|
FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
|
|
362
422
|
|
|
363
|
-
size_t
|
|
423
|
+
size_t cur_nprobe = this->nprobe;
|
|
424
|
+
uint8_t used_qb = qb;
|
|
425
|
+
bool used_centered = centered;
|
|
364
426
|
if (params) {
|
|
365
427
|
FAISS_THROW_IF_NOT(params->max_codes == 0);
|
|
366
|
-
|
|
428
|
+
cur_nprobe = params->nprobe;
|
|
429
|
+
if (auto rparams =
|
|
430
|
+
dynamic_cast<const IVFRaBitQSearchParameters*>(params)) {
|
|
431
|
+
used_qb = rparams->qb;
|
|
432
|
+
used_centered = rparams->centered;
|
|
433
|
+
}
|
|
367
434
|
}
|
|
368
435
|
|
|
369
|
-
std::vector<QueryFactorsData> query_factors_storage(n *
|
|
436
|
+
std::vector<QueryFactorsData> query_factors_storage(n * cur_nprobe);
|
|
370
437
|
FastScanDistancePostProcessing context;
|
|
371
438
|
context.query_factors = query_factors_storage.data();
|
|
372
|
-
context.nprobe =
|
|
439
|
+
context.nprobe = cur_nprobe;
|
|
440
|
+
context.qb = used_qb;
|
|
441
|
+
context.centered = used_centered;
|
|
373
442
|
|
|
374
|
-
const CoarseQuantized cq = {
|
|
443
|
+
const CoarseQuantized cq = {cur_nprobe, centroid_dis, assign};
|
|
375
444
|
search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
|
|
376
445
|
}
|
|
377
446
|
|
|
@@ -385,44 +454,165 @@ void IndexIVFRaBitQFastScan::compute_LUT(
|
|
|
385
454
|
FAISS_THROW_IF_NOT(is_trained);
|
|
386
455
|
FAISS_THROW_IF_NOT(by_residual);
|
|
387
456
|
|
|
388
|
-
|
|
457
|
+
// Use overridden qb/centered from context if provided, else index defaults
|
|
458
|
+
const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
|
|
459
|
+
const bool used_centered = context.qb > 0 ? context.centered : centered;
|
|
460
|
+
|
|
461
|
+
size_t cq_nprobe = cq.nprobe;
|
|
389
462
|
|
|
390
463
|
size_t dim12 = 16 * M;
|
|
391
464
|
|
|
392
|
-
dis_tables.resize(n *
|
|
393
|
-
biases.resize(n *
|
|
465
|
+
dis_tables.resize(n * cq_nprobe * dim12);
|
|
466
|
+
biases.resize(n * cq_nprobe);
|
|
394
467
|
|
|
395
|
-
if (n *
|
|
396
|
-
memset(biases.get(), 0, sizeof(float) * n *
|
|
468
|
+
if (n * cq_nprobe > 0) {
|
|
469
|
+
memset(biases.get(), 0, sizeof(float) * n * cq_nprobe);
|
|
397
470
|
}
|
|
398
|
-
|
|
471
|
+
// Use per-thread buffers instead of one O(n * nprobe * d) allocation.
|
|
472
|
+
// rotated_q / centroid_buf keep their capacity across iterations so the
|
|
473
|
+
// allocator is only hit once per thread.
|
|
474
|
+
#pragma omp parallel if (n * cq_nprobe > 1000)
|
|
475
|
+
{
|
|
476
|
+
std::vector<float> rotated_q(d);
|
|
477
|
+
std::vector<float> centroid_buf(d);
|
|
478
|
+
|
|
479
|
+
#pragma omp for
|
|
480
|
+
for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
|
|
481
|
+
idx_t i = ij / cq_nprobe;
|
|
482
|
+
idx_t cij = cq.ids[ij];
|
|
483
|
+
|
|
484
|
+
if (cij >= 0) {
|
|
485
|
+
QueryFactorsData query_factors_data;
|
|
486
|
+
|
|
487
|
+
compute_residual_LUT(
|
|
488
|
+
x + i * d,
|
|
489
|
+
cij,
|
|
490
|
+
query_factors_data,
|
|
491
|
+
dis_tables.get() + ij * dim12,
|
|
492
|
+
used_qb,
|
|
493
|
+
used_centered,
|
|
494
|
+
rotated_q,
|
|
495
|
+
centroid_buf);
|
|
496
|
+
|
|
497
|
+
if (context.query_factors != nullptr) {
|
|
498
|
+
context.query_factors[ij] = std::move(query_factors_data);
|
|
499
|
+
}
|
|
399
500
|
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
501
|
+
} else {
|
|
502
|
+
memset(dis_tables.get() + ij * dim12, 0, sizeof(float) * dim12);
|
|
503
|
+
}
|
|
504
|
+
}
|
|
505
|
+
}
|
|
506
|
+
}
|
|
405
507
|
|
|
406
|
-
|
|
407
|
-
|
|
508
|
+
void IndexIVFRaBitQFastScan::compute_LUT_uint8(
|
|
509
|
+
size_t n,
|
|
510
|
+
const float* x,
|
|
511
|
+
const CoarseQuantized& cq,
|
|
512
|
+
AlignedTable<uint8_t>& dis_tables,
|
|
513
|
+
AlignedTable<uint16_t>& biases,
|
|
514
|
+
float* normalizers,
|
|
515
|
+
const FastScanDistancePostProcessing& context) const {
|
|
516
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
517
|
+
FAISS_THROW_IF_NOT(by_residual);
|
|
408
518
|
|
|
409
|
-
|
|
410
|
-
|
|
519
|
+
const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
|
|
520
|
+
const bool used_centered = context.qb > 0 ? context.centered : centered;
|
|
521
|
+
const size_t cur_nprobe = cq.nprobe;
|
|
522
|
+
const size_t dim12 = 16 * M;
|
|
523
|
+
const size_t dim12_2 = 16 * M2;
|
|
411
524
|
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
dis_tables.get() + ij * dim12,
|
|
416
|
-
x + i * d);
|
|
525
|
+
// Allocate only the uint8 output table (no full float table)
|
|
526
|
+
dis_tables.resize(n * cur_nprobe * dim12_2);
|
|
527
|
+
biases.resize(n * cur_nprobe);
|
|
417
528
|
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
529
|
+
#pragma omp parallel if (n > 1)
|
|
530
|
+
{
|
|
531
|
+
// Per-thread buffers reused across queries
|
|
532
|
+
AlignedTable<float> lut_float(cur_nprobe * dim12);
|
|
533
|
+
std::vector<float> rotated_q(d);
|
|
534
|
+
std::vector<float> centroid_buf(d);
|
|
535
|
+
std::vector<float> all_mins(cur_nprobe * M);
|
|
536
|
+
std::vector<float> probe_b(cur_nprobe);
|
|
537
|
+
|
|
538
|
+
#pragma omp for schedule(dynamic)
|
|
539
|
+
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
|
|
540
|
+
const float* xi = x + i * d;
|
|
541
|
+
|
|
542
|
+
// Compute float LUT for all probes using fused path
|
|
543
|
+
for (size_t j = 0; j < cur_nprobe; j++) {
|
|
544
|
+
const size_t ij = i * cur_nprobe + j;
|
|
545
|
+
idx_t cij = cq.ids[ij];
|
|
546
|
+
|
|
547
|
+
if (cij >= 0) {
|
|
548
|
+
QueryFactorsData qf;
|
|
549
|
+
compute_residual_LUT(
|
|
550
|
+
xi,
|
|
551
|
+
cij,
|
|
552
|
+
qf,
|
|
553
|
+
lut_float.get() + j * dim12,
|
|
554
|
+
used_qb,
|
|
555
|
+
used_centered,
|
|
556
|
+
rotated_q,
|
|
557
|
+
centroid_buf);
|
|
558
|
+
|
|
559
|
+
if (context.query_factors != nullptr) {
|
|
560
|
+
context.query_factors[ij] = qf;
|
|
561
|
+
}
|
|
562
|
+
} else {
|
|
563
|
+
memset(lut_float.get() + j * dim12,
|
|
564
|
+
0,
|
|
565
|
+
sizeof(float) * dim12);
|
|
566
|
+
}
|
|
421
567
|
}
|
|
422
568
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
569
|
+
// Quantize float LUT to uint8 inline.
|
|
570
|
+
// Mirrors quantize_LUT_and_bias 3D path with zero biases.
|
|
571
|
+
// Single pass: find per-sub-q mins, max span, and per-probe b.
|
|
572
|
+
float glob_max_span = -HUGE_VAL;
|
|
573
|
+
float glob_max_dis = -HUGE_VAL;
|
|
574
|
+
float glob_b = HUGE_VAL;
|
|
575
|
+
for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
|
|
576
|
+
float b_j = 0;
|
|
577
|
+
float span_j = 0;
|
|
578
|
+
for (size_t m = 0; m < M; m++) {
|
|
579
|
+
const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
|
|
580
|
+
float mn = tab[0], mx = tab[0];
|
|
581
|
+
for (size_t s = 1; s < ksub; s++) {
|
|
582
|
+
mn = std::min(mn, tab[s]);
|
|
583
|
+
mx = std::max(mx, tab[s]);
|
|
584
|
+
}
|
|
585
|
+
all_mins[j2 * M + m] = mn;
|
|
586
|
+
float span = mx - mn;
|
|
587
|
+
glob_max_span = std::max(glob_max_span, span);
|
|
588
|
+
b_j += mn;
|
|
589
|
+
span_j += span;
|
|
590
|
+
}
|
|
591
|
+
probe_b[j2] = b_j;
|
|
592
|
+
glob_max_dis = std::max(glob_max_dis, span_j);
|
|
593
|
+
glob_b = std::min(glob_b, b_j);
|
|
594
|
+
}
|
|
595
|
+
float a = std::min(255.0f / glob_max_span, 65535.0f / glob_max_dis);
|
|
596
|
+
|
|
597
|
+
// Second pass: quantize LUT and compute biasq
|
|
598
|
+
uint8_t* out_base = dis_tables.get() + i * cur_nprobe * dim12_2;
|
|
599
|
+
uint16_t* bq = biases.get() + i * cur_nprobe;
|
|
600
|
+
for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
|
|
601
|
+
for (size_t m = 0; m < M; m++) {
|
|
602
|
+
const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
|
|
603
|
+
float mn = all_mins[j2 * M + m];
|
|
604
|
+
uint8_t* out = out_base + j2 * dim12_2 + m * ksub;
|
|
605
|
+
for (size_t s = 0; s < ksub; s++) {
|
|
606
|
+
out[s] = static_cast<uint8_t>(
|
|
607
|
+
std::roundf(a * (tab[s] - mn)));
|
|
608
|
+
}
|
|
609
|
+
}
|
|
610
|
+
memset(out_base + j2 * dim12_2 + M * ksub, 0, (M2 - M) * ksub);
|
|
611
|
+
bq[j2] = static_cast<uint16_t>(
|
|
612
|
+
std::roundf(a * (probe_b[j2] - glob_b)));
|
|
613
|
+
}
|
|
614
|
+
normalizers[2 * i] = a;
|
|
615
|
+
normalizers[2 * i + 1] = glob_b;
|
|
426
616
|
}
|
|
427
617
|
}
|
|
428
618
|
}
|
|
@@ -477,7 +667,7 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
|
|
|
477
667
|
fastscan_code.data(), residual.data(), dp_multiplier);
|
|
478
668
|
|
|
479
669
|
// Reconstruct: x = centroid + residual
|
|
480
|
-
for (size_t j = 0; j < d; j++) {
|
|
670
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
481
671
|
recons[j] = centroid[j] + residual[j];
|
|
482
672
|
}
|
|
483
673
|
}
|
|
@@ -502,7 +692,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
502
692
|
|
|
503
693
|
idx_t list_no = decode_listno(code_i);
|
|
504
694
|
|
|
505
|
-
if (list_no >= 0 && list_no < nlist) {
|
|
695
|
+
if (list_no >= 0 && list_no < static_cast<idx_t>(nlist)) {
|
|
506
696
|
quantizer->reconstruct(list_no, centroid.data());
|
|
507
697
|
|
|
508
698
|
const uint8_t* fastscan_code = code_i + coarse_size;
|
|
@@ -514,7 +704,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
514
704
|
decode_fastscan_to_residual(
|
|
515
705
|
fastscan_code, residual.data(), base_factors.dp_multiplier);
|
|
516
706
|
|
|
517
|
-
for (size_t j = 0; j < d; j++) {
|
|
707
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
518
708
|
x_i[j] = centroid[j] + residual[j];
|
|
519
709
|
}
|
|
520
710
|
} else {
|
|
@@ -531,7 +721,7 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
|
|
|
531
721
|
|
|
532
722
|
const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
533
723
|
|
|
534
|
-
for (size_t j = 0; j < d; j++) {
|
|
724
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
535
725
|
bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);
|
|
536
726
|
|
|
537
727
|
float bit_as_float = bit_value ? 1.0f : 0.0f;
|
|
@@ -539,287 +729,18 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
|
|
|
539
729
|
}
|
|
540
730
|
}
|
|
541
731
|
|
|
542
|
-
|
|
543
|
-
SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
|
|
732
|
+
std::unique_ptr<FastScanCodeScanner> IndexIVFRaBitQFastScan::make_knn_scanner(
|
|
544
733
|
bool is_max,
|
|
545
|
-
int /* impl */,
|
|
546
734
|
idx_t n,
|
|
547
735
|
idx_t k,
|
|
548
736
|
float* distances,
|
|
549
737
|
idx_t* labels,
|
|
550
|
-
const IDSelector*
|
|
551
|
-
|
|
552
|
-
const
|
|
553
|
-
const
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
if (is_max) {
|
|
557
|
-
return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
|
|
558
|
-
this, n, k, distances, labels, &context, is_multibit);
|
|
559
|
-
} else {
|
|
560
|
-
return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
|
|
561
|
-
this, n, k, distances, labels, &context, is_multibit);
|
|
562
|
-
}
|
|
563
|
-
}
|
|
564
|
-
|
|
565
|
-
/*********************************************************
|
|
566
|
-
* IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler implementation
|
|
567
|
-
*********************************************************/
|
|
568
|
-
|
|
569
|
-
template <class C>
|
|
570
|
-
IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
|
|
571
|
-
const IndexIVFRaBitQFastScan* idx,
|
|
572
|
-
size_t nq_val,
|
|
573
|
-
size_t k_val,
|
|
574
|
-
float* distances,
|
|
575
|
-
int64_t* labels,
|
|
576
|
-
const FastScanDistancePostProcessing* ctx,
|
|
577
|
-
bool multibit)
|
|
578
|
-
: simd_result_handlers::ResultHandlerCompare<C, true>(
|
|
579
|
-
nq_val,
|
|
580
|
-
0,
|
|
581
|
-
nullptr),
|
|
582
|
-
index(idx),
|
|
583
|
-
heap_distances(distances),
|
|
584
|
-
heap_labels(labels),
|
|
585
|
-
nq(nq_val),
|
|
586
|
-
k(k_val),
|
|
587
|
-
context(ctx),
|
|
588
|
-
is_multibit(multibit),
|
|
589
|
-
storage_size(idx->compute_per_vector_storage_size()),
|
|
590
|
-
packed_block_size(((idx->M2 + 1) / 2) * idx->bbs),
|
|
591
|
-
full_block_size(idx->get_block_stride()),
|
|
592
|
-
packer(idx->get_CodePacker()) {
|
|
593
|
-
current_list_no = 0;
|
|
594
|
-
probe_indices.clear();
|
|
595
|
-
|
|
596
|
-
// Initialize heaps in constructor (standard pattern from HeapHandler)
|
|
597
|
-
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
598
|
-
float* heap_dis = heap_distances + q * k;
|
|
599
|
-
int64_t* heap_ids = heap_labels + q * k;
|
|
600
|
-
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
|
|
601
|
-
}
|
|
602
|
-
}
|
|
603
|
-
|
|
604
|
-
template <class C>
|
|
605
|
-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
|
|
606
|
-
size_t q,
|
|
607
|
-
size_t b,
|
|
608
|
-
simd16uint16 d0,
|
|
609
|
-
simd16uint16 d1) {
|
|
610
|
-
// Store the original local query index before adjust_with_origin changes it
|
|
611
|
-
size_t local_q = q;
|
|
612
|
-
this->adjust_with_origin(q, d0, d1);
|
|
613
|
-
|
|
614
|
-
ALIGNED(32) uint16_t d32tab[32];
|
|
615
|
-
d0.store(d32tab);
|
|
616
|
-
d1.store(d32tab + 16);
|
|
617
|
-
|
|
618
|
-
float* const heap_dis = heap_distances + q * k;
|
|
619
|
-
int64_t* const heap_ids = heap_labels + q * k;
|
|
620
|
-
|
|
621
|
-
FAISS_THROW_IF_NOT_FMT(
|
|
622
|
-
!probe_indices.empty() && local_q < probe_indices.size(),
|
|
623
|
-
"set_list_context() must be called before handle() - probe_indices size: %zu, local_q: %zu, global_q: %zu",
|
|
624
|
-
probe_indices.size(),
|
|
625
|
-
local_q,
|
|
626
|
-
q);
|
|
627
|
-
|
|
628
|
-
// Access query factors directly from array via ProcessingContext
|
|
629
|
-
if (!context || !context->query_factors) {
|
|
630
|
-
FAISS_THROW_MSG(
|
|
631
|
-
"Query factors not available: FastScanDistancePostProcessing with query_factors required");
|
|
632
|
-
}
|
|
633
|
-
|
|
634
|
-
// Use probe_rank from probe_indices for compact storage indexing
|
|
635
|
-
size_t probe_rank = probe_indices[local_q];
|
|
636
|
-
size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
|
|
637
|
-
size_t storage_idx = q * nprobe + probe_rank;
|
|
638
|
-
|
|
639
|
-
const auto& query_factors = context->query_factors[storage_idx];
|
|
640
|
-
|
|
641
|
-
const float one_a =
|
|
642
|
-
this->normalizers ? (1.0f / this->normalizers[2 * q]) : 1.0f;
|
|
643
|
-
const float bias = this->normalizers ? this->normalizers[2 * q + 1] : 0.0f;
|
|
644
|
-
|
|
645
|
-
uint64_t idx_base = this->j0 + b * 32;
|
|
646
|
-
if (idx_base >= this->ntotal) {
|
|
647
|
-
return;
|
|
648
|
-
}
|
|
649
|
-
|
|
650
|
-
size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
|
|
651
|
-
|
|
652
|
-
// Stats tracking for two-stage search
|
|
653
|
-
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
654
|
-
// n_multibit_evaluations: candidates requiring full multi-bit distance
|
|
655
|
-
size_t local_1bit_evaluations = 0;
|
|
656
|
-
size_t local_multibit_evaluations = 0;
|
|
657
|
-
|
|
658
|
-
// Process each candidate vector in the SIMD batch
|
|
659
|
-
for (size_t j = 0; j < max_positions; j++) {
|
|
660
|
-
const int64_t result_id = this->adjust_id(b, j);
|
|
661
|
-
|
|
662
|
-
if (result_id < 0) {
|
|
663
|
-
continue;
|
|
664
|
-
}
|
|
665
|
-
|
|
666
|
-
const float normalized_distance = d32tab[j] * one_a + bias;
|
|
667
|
-
|
|
668
|
-
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
669
|
-
list_codes_ptr,
|
|
670
|
-
idx_base + j,
|
|
671
|
-
index->bbs,
|
|
672
|
-
packed_block_size,
|
|
673
|
-
full_block_size,
|
|
674
|
-
storage_size);
|
|
675
|
-
|
|
676
|
-
if (is_multibit) {
|
|
677
|
-
// Track candidates actually considered for two-stage filtering
|
|
678
|
-
local_1bit_evaluations++;
|
|
679
|
-
|
|
680
|
-
// Multi-bit: use SignBitFactorsWithError and two-stage search
|
|
681
|
-
const SignBitFactorsWithError& full_factors =
|
|
682
|
-
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
683
|
-
|
|
684
|
-
// Compute 1-bit adjusted distance using shared helper
|
|
685
|
-
float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
|
|
686
|
-
normalized_distance,
|
|
687
|
-
full_factors,
|
|
688
|
-
query_factors,
|
|
689
|
-
index->centered,
|
|
690
|
-
index->qb,
|
|
691
|
-
index->d);
|
|
692
|
-
|
|
693
|
-
// Adaptive filtering: decide whether to compute full distance
|
|
694
|
-
const bool is_similarity =
|
|
695
|
-
index->metric_type == MetricType::METRIC_INNER_PRODUCT;
|
|
696
|
-
|
|
697
|
-
float g_error = query_factors.g_error;
|
|
698
|
-
|
|
699
|
-
bool should_refine = rabitq_utils::should_refine_candidate(
|
|
700
|
-
dist_1bit,
|
|
701
|
-
full_factors.f_error,
|
|
702
|
-
g_error,
|
|
703
|
-
heap_dis[0],
|
|
704
|
-
is_similarity);
|
|
705
|
-
if (should_refine) {
|
|
706
|
-
local_multibit_evaluations++;
|
|
707
|
-
|
|
708
|
-
// Compute local_offset: position within current inverted list
|
|
709
|
-
size_t local_offset = this->j0 + b * 32 + j;
|
|
710
|
-
|
|
711
|
-
// Compute full multi-bit distance
|
|
712
|
-
float dist_full = compute_full_multibit_distance(
|
|
713
|
-
result_id, local_q, q, local_offset);
|
|
714
|
-
|
|
715
|
-
// Update heap if this distance is better
|
|
716
|
-
if (Cfloat::cmp(heap_dis[0], dist_full)) {
|
|
717
|
-
heap_replace_top<Cfloat>(
|
|
718
|
-
k, heap_dis, heap_ids, dist_full, result_id);
|
|
719
|
-
nup++;
|
|
720
|
-
}
|
|
721
|
-
}
|
|
722
|
-
} else {
|
|
723
|
-
const auto& db_factors =
|
|
724
|
-
*reinterpret_cast<const SignBitFactors*>(base_ptr);
|
|
725
|
-
|
|
726
|
-
// Compute adjusted distance using shared helper
|
|
727
|
-
float adjusted_distance =
|
|
728
|
-
rabitq_utils::compute_1bit_adjusted_distance(
|
|
729
|
-
normalized_distance,
|
|
730
|
-
db_factors,
|
|
731
|
-
query_factors,
|
|
732
|
-
index->centered,
|
|
733
|
-
index->qb,
|
|
734
|
-
index->d);
|
|
735
|
-
|
|
736
|
-
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
737
|
-
heap_replace_top<Cfloat>(
|
|
738
|
-
k, heap_dis, heap_ids, adjusted_distance, result_id);
|
|
739
|
-
nup++;
|
|
740
|
-
}
|
|
741
|
-
}
|
|
742
|
-
}
|
|
743
|
-
|
|
744
|
-
// Update global stats atomically
|
|
745
|
-
#pragma omp atomic
|
|
746
|
-
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
|
|
747
|
-
#pragma omp atomic
|
|
748
|
-
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
|
|
749
|
-
}
|
|
750
|
-
|
|
751
|
-
template <class C>
|
|
752
|
-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
|
|
753
|
-
size_t list_no,
|
|
754
|
-
const std::vector<int>& probe_map) {
|
|
755
|
-
current_list_no = list_no;
|
|
756
|
-
probe_indices = probe_map;
|
|
757
|
-
list_codes_ptr = index->invlists->get_codes(list_no);
|
|
758
|
-
}
|
|
759
|
-
|
|
760
|
-
template <class C>
|
|
761
|
-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::begin(
|
|
762
|
-
const float* norms) {
|
|
763
|
-
this->normalizers = norms;
|
|
764
|
-
}
|
|
765
|
-
|
|
766
|
-
template <class C>
|
|
767
|
-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
|
|
768
|
-
#pragma omp parallel for
|
|
769
|
-
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
770
|
-
float* heap_dis = heap_distances + q * k;
|
|
771
|
-
int64_t* heap_ids = heap_labels + q * k;
|
|
772
|
-
heap_reorder<Cfloat>(k, heap_dis, heap_ids);
|
|
773
|
-
}
|
|
774
|
-
}
|
|
775
|
-
|
|
776
|
-
template <class C>
|
|
777
|
-
float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
|
|
778
|
-
compute_full_multibit_distance(
|
|
779
|
-
size_t /*db_idx*/,
|
|
780
|
-
size_t local_q,
|
|
781
|
-
size_t global_q,
|
|
782
|
-
size_t local_offset) const {
|
|
783
|
-
const size_t ex_bits = index->rabitq.nb_bits - 1;
|
|
784
|
-
const size_t dim = index->d;
|
|
785
|
-
|
|
786
|
-
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
787
|
-
list_codes_ptr,
|
|
788
|
-
local_offset,
|
|
789
|
-
index->bbs,
|
|
790
|
-
packed_block_size,
|
|
791
|
-
full_block_size,
|
|
792
|
-
storage_size);
|
|
793
|
-
|
|
794
|
-
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
795
|
-
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
796
|
-
const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
|
|
797
|
-
base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
|
|
798
|
-
|
|
799
|
-
// Use local_q to access probe_indices (batch-local), global_q for storage
|
|
800
|
-
size_t probe_rank = probe_indices[local_q];
|
|
801
|
-
size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
|
|
802
|
-
size_t storage_idx = global_q * nprobe + probe_rank;
|
|
803
|
-
const auto& query_factors = context->query_factors[storage_idx];
|
|
804
|
-
|
|
805
|
-
size_t list_no = current_list_no;
|
|
806
|
-
InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
|
|
807
|
-
|
|
808
|
-
std::vector<uint8_t> unpacked_code(index->code_size);
|
|
809
|
-
packer->unpack_1(list_codes.get(), local_offset, unpacked_code.data());
|
|
810
|
-
const uint8_t* sign_bits = unpacked_code.data();
|
|
811
|
-
|
|
812
|
-
return rabitq_utils::compute_full_multibit_distance(
|
|
813
|
-
sign_bits,
|
|
814
|
-
ex_code,
|
|
815
|
-
ex_fac,
|
|
816
|
-
query_factors.rotated_q.data(),
|
|
817
|
-
(index->metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
818
|
-
? query_factors.q_dot_c
|
|
819
|
-
: query_factors.qr_to_c_L2sqr,
|
|
820
|
-
dim,
|
|
821
|
-
ex_bits,
|
|
822
|
-
index->metric_type);
|
|
738
|
+
const IDSelector* sel,
|
|
739
|
+
int /*impl*/,
|
|
740
|
+
const FastScanDistancePostProcessing& context) const {
|
|
741
|
+
const bool is_multibit = (rabitq.nb_bits - 1) > 0;
|
|
742
|
+
return rabitq_ivf_make_knn_scanner(
|
|
743
|
+
is_max, this, n, k, distances, labels, sel, &context, is_multibit);
|
|
823
744
|
}
|
|
824
745
|
|
|
825
746
|
/*********************************************************
|
|
@@ -829,139 +750,209 @@ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
|
|
|
829
750
|
namespace {
|
|
830
751
|
|
|
831
752
|
/// Provides IVF scanner interface using FastScan's SIMD batch processing.
|
|
753
|
+
/// Buffers are allocated once and reused across set_list + scan_codes calls.
|
|
832
754
|
struct IVFRaBitQFastScanScanner : InvertedListScanner {
|
|
833
|
-
|
|
755
|
+
using InvertedListScanner::scan_codes;
|
|
834
756
|
static constexpr size_t nq = 1;
|
|
835
757
|
|
|
836
758
|
const IndexIVFRaBitQFastScan& index;
|
|
759
|
+
const uint8_t qb;
|
|
760
|
+
const bool centered;
|
|
837
761
|
|
|
762
|
+
const float* xi = nullptr;
|
|
763
|
+
|
|
764
|
+
// Reusable buffers (allocated once in constructor)
|
|
838
765
|
AlignedTable<uint8_t> dis_tables;
|
|
839
766
|
AlignedTable<uint16_t> biases;
|
|
840
|
-
/// [scale, offset] for converting uint16 to float
|
|
841
767
|
std::array<float, 2> normalizers{};
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
768
|
+
AlignedTable<float> lut_float;
|
|
769
|
+
std::vector<float> rotated_q;
|
|
770
|
+
std::vector<float> centroid_buf;
|
|
845
771
|
QueryFactorsData query_factors;
|
|
846
772
|
FastScanDistancePostProcessing context;
|
|
773
|
+
std::vector<int> probe_map;
|
|
774
|
+
std::vector<float> mins_buf;
|
|
847
775
|
|
|
776
|
+
// Distance computer for distance_to_code (created in set_list)
|
|
848
777
|
std::unique_ptr<FlatCodesDistanceComputer> dc;
|
|
849
|
-
std::vector<float> centroid;
|
|
850
778
|
|
|
851
779
|
IVFRaBitQFastScanScanner(
|
|
852
|
-
const IndexIVFRaBitQFastScan&
|
|
853
|
-
bool
|
|
854
|
-
const IDSelector*
|
|
855
|
-
|
|
856
|
-
|
|
780
|
+
const IndexIVFRaBitQFastScan& index_in,
|
|
781
|
+
bool store_pairs_in,
|
|
782
|
+
const IDSelector* sel_in,
|
|
783
|
+
uint8_t qb_in,
|
|
784
|
+
bool centered_in)
|
|
785
|
+
: InvertedListScanner(store_pairs_in, sel_in),
|
|
786
|
+
index(index_in),
|
|
787
|
+
qb(qb_in),
|
|
788
|
+
centered(centered_in),
|
|
789
|
+
lut_float(16 * index_in.M),
|
|
790
|
+
rotated_q(index_in.d),
|
|
791
|
+
centroid_buf(index_in.d),
|
|
792
|
+
probe_map({0}),
|
|
793
|
+
mins_buf(index_in.M) {
|
|
794
|
+
this->keep_max = is_similarity_metric(index_in.metric_type);
|
|
795
|
+
this->code_size = index_in.code_size;
|
|
796
|
+
|
|
797
|
+
// Pre-allocate output tables for single probe
|
|
798
|
+
dis_tables.resize(16 * index_in.M2);
|
|
799
|
+
biases.resize(1);
|
|
800
|
+
|
|
801
|
+
// Set up context once
|
|
802
|
+
context.query_factors = &query_factors;
|
|
803
|
+
context.nprobe = 1;
|
|
804
|
+
context.qb = qb;
|
|
805
|
+
context.centered = centered;
|
|
857
806
|
}
|
|
858
807
|
|
|
859
808
|
void set_query(const float* query) override {
|
|
860
809
|
this->xi = query;
|
|
861
810
|
}
|
|
862
811
|
|
|
863
|
-
void set_list(idx_t
|
|
864
|
-
this->list_no =
|
|
812
|
+
void set_list(idx_t list_no_in, float /*coarse_dis_in*/) override {
|
|
813
|
+
this->list_no = list_no_in;
|
|
814
|
+
|
|
815
|
+
index.compute_residual_LUT(
|
|
816
|
+
xi,
|
|
817
|
+
list_no_in,
|
|
818
|
+
query_factors,
|
|
819
|
+
lut_float.get(),
|
|
820
|
+
qb,
|
|
821
|
+
centered,
|
|
822
|
+
rotated_q,
|
|
823
|
+
centroid_buf);
|
|
824
|
+
|
|
825
|
+
// Single-probe quantization (simplified inline, no OMP, no 3D)
|
|
826
|
+
const size_t M = index.M;
|
|
827
|
+
const size_t M2 = index.M2;
|
|
828
|
+
const size_t ksub = index.ksub;
|
|
829
|
+
|
|
830
|
+
float max_span = -HUGE_VAL;
|
|
831
|
+
float max_dis = 0;
|
|
832
|
+
float b = 0;
|
|
833
|
+
float* mins = mins_buf.data();
|
|
865
834
|
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
835
|
+
for (size_t m = 0; m < M; m++) {
|
|
836
|
+
const float* tab = lut_float.get() + m * ksub;
|
|
837
|
+
float mn = tab[0], mx = tab[0];
|
|
838
|
+
for (size_t s = 1; s < ksub; s++) {
|
|
839
|
+
mn = std::min(mn, tab[s]);
|
|
840
|
+
mx = std::max(mx, tab[s]);
|
|
841
|
+
}
|
|
842
|
+
mins[m] = mn;
|
|
843
|
+
float span = mx - mn;
|
|
844
|
+
max_span = std::max(max_span, span);
|
|
845
|
+
max_dis += span;
|
|
846
|
+
b += mn;
|
|
847
|
+
}
|
|
876
848
|
|
|
877
|
-
|
|
878
|
-
|
|
849
|
+
float a = std::min(255.0f / max_span, 65535.0f / max_dis);
|
|
850
|
+
uint8_t* out = dis_tables.get();
|
|
851
|
+
for (size_t m = 0; m < M; m++) {
|
|
852
|
+
const float* tab = lut_float.get() + m * ksub;
|
|
853
|
+
for (size_t s = 0; s < ksub; s++) {
|
|
854
|
+
out[m * ksub + s] = static_cast<uint8_t>(
|
|
855
|
+
std::roundf(a * (tab[s] - mins[m])));
|
|
856
|
+
}
|
|
857
|
+
}
|
|
858
|
+
memset(out + M * ksub, 0, (M2 - M) * ksub);
|
|
859
|
+
biases[0] = 0;
|
|
860
|
+
normalizers[0] = a;
|
|
861
|
+
normalizers[1] = b;
|
|
879
862
|
|
|
880
|
-
//
|
|
881
|
-
|
|
882
|
-
index.quantizer->reconstruct(list_no, centroid.data());
|
|
863
|
+
// Create distance computer (reuses centroid_buf from
|
|
864
|
+
// compute_residual_LUT)
|
|
883
865
|
dc.reset(index.rabitq.get_distance_computer(
|
|
884
|
-
|
|
866
|
+
qb, centroid_buf.data(), centered));
|
|
885
867
|
dc->set_query(xi);
|
|
886
868
|
}
|
|
887
869
|
|
|
888
870
|
float distance_to_code(const uint8_t* code) const override {
|
|
889
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
890
|
-
dc,
|
|
891
|
-
"set_query and set_list must be called before distance_to_code");
|
|
892
871
|
return dc->distance_to_code(code);
|
|
893
872
|
}
|
|
894
873
|
|
|
895
|
-
public:
|
|
896
874
|
size_t scan_codes(
|
|
897
875
|
size_t ntotal,
|
|
898
876
|
const uint8_t* codes,
|
|
899
877
|
const idx_t* ids,
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
size_t k
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
k,
|
|
914
|
-
curr_dists.data(),
|
|
915
|
-
curr_labels.data(),
|
|
916
|
-
sel,
|
|
917
|
-
context,
|
|
918
|
-
&normalizers[0]));
|
|
919
|
-
|
|
920
|
-
int qmap1[1] = {0};
|
|
921
|
-
handler->q_map = qmap1;
|
|
922
|
-
handler->begin(&normalizers[0]);
|
|
923
|
-
|
|
924
|
-
const uint8_t* LUT = dis_tables.get();
|
|
925
|
-
handler->dbias = biases.get();
|
|
926
|
-
handler->ntotal = ntotal;
|
|
927
|
-
handler->id_map = ids;
|
|
928
|
-
|
|
929
|
-
// RaBitQ needs list context for factor lookup
|
|
930
|
-
std::vector<int> probe_map = {0};
|
|
931
|
-
handler->set_list_context(list_no, probe_map);
|
|
932
|
-
|
|
933
|
-
pq4_accumulate_loop(
|
|
934
|
-
1,
|
|
935
|
-
roundup(ntotal, index.bbs),
|
|
936
|
-
index.bbs,
|
|
937
|
-
static_cast<int>(index.M2),
|
|
938
|
-
codes,
|
|
939
|
-
LUT,
|
|
940
|
-
*handler,
|
|
941
|
-
nullptr,
|
|
942
|
-
index.get_block_stride());
|
|
943
|
-
|
|
944
|
-
// Combine results across iterations
|
|
945
|
-
handler->end();
|
|
946
|
-
if (keep_max) {
|
|
947
|
-
minheap_addn(
|
|
878
|
+
ResultHandler& result_handler) const override {
|
|
879
|
+
auto scan_with_heap = [&](auto* heap_handler) -> size_t {
|
|
880
|
+
const size_t k = heap_handler->k;
|
|
881
|
+
if (k == 0) {
|
|
882
|
+
return 0;
|
|
883
|
+
}
|
|
884
|
+
|
|
885
|
+
std::vector<float> curr_dists(k, result_handler.threshold);
|
|
886
|
+
std::vector<idx_t> curr_labels(k, -1);
|
|
887
|
+
|
|
888
|
+
auto scanner = index.make_knn_scanner(
|
|
889
|
+
!keep_max,
|
|
890
|
+
nq,
|
|
948
891
|
k,
|
|
949
|
-
distances,
|
|
950
|
-
labels,
|
|
951
892
|
curr_dists.data(),
|
|
952
893
|
curr_labels.data(),
|
|
953
|
-
|
|
894
|
+
sel,
|
|
895
|
+
0,
|
|
896
|
+
context);
|
|
897
|
+
auto* handler = scanner->handler();
|
|
898
|
+
|
|
899
|
+
int qmap1[1] = {0};
|
|
900
|
+
handler->q_map = qmap1;
|
|
901
|
+
handler->begin(&normalizers[0]);
|
|
902
|
+
handler->dbias = biases.get();
|
|
903
|
+
handler->ntotal = ntotal;
|
|
904
|
+
handler->id_map = ids;
|
|
905
|
+
|
|
906
|
+
handler->set_list_context(list_no, probe_map);
|
|
907
|
+
if (!handler->list_codes_ptr) {
|
|
908
|
+
handler->list_codes_ptr = codes;
|
|
909
|
+
}
|
|
910
|
+
|
|
911
|
+
scanner->accumulate_loop(
|
|
912
|
+
1,
|
|
913
|
+
roundup(ntotal, index.bbs),
|
|
914
|
+
index.bbs,
|
|
915
|
+
static_cast<int>(index.M2),
|
|
916
|
+
codes,
|
|
917
|
+
dis_tables.get(),
|
|
918
|
+
0,
|
|
919
|
+
index.get_block_stride());
|
|
920
|
+
|
|
921
|
+
const size_t scan_cnt = handler->count_scanned_rows();
|
|
922
|
+
handler->end();
|
|
923
|
+
|
|
924
|
+
result_handler.stats.scan_cnt += scan_cnt;
|
|
925
|
+
size_t nup = 0;
|
|
926
|
+
for (size_t j = 0; j < k; j++) {
|
|
927
|
+
if (curr_labels[j] < 0) {
|
|
928
|
+
continue;
|
|
929
|
+
}
|
|
930
|
+
if (result_handler.add_result(curr_dists[j], curr_labels[j])) {
|
|
931
|
+
result_handler.stats.nheap_updates++;
|
|
932
|
+
nup++;
|
|
933
|
+
}
|
|
934
|
+
}
|
|
935
|
+
return nup;
|
|
936
|
+
};
|
|
937
|
+
|
|
938
|
+
if (!keep_max) {
|
|
939
|
+
using C = CMax<float, idx_t>;
|
|
940
|
+
if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
|
|
941
|
+
&result_handler)) {
|
|
942
|
+
return scan_with_heap(heap_handler);
|
|
943
|
+
}
|
|
954
944
|
} else {
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
curr_labels.data(),
|
|
961
|
-
k);
|
|
945
|
+
using C = CMin<float, idx_t>;
|
|
946
|
+
if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
|
|
947
|
+
&result_handler)) {
|
|
948
|
+
return scan_with_heap(heap_handler);
|
|
949
|
+
}
|
|
962
950
|
}
|
|
963
951
|
|
|
964
|
-
|
|
952
|
+
FAISS_THROW_MSG(
|
|
953
|
+
"IVFRaBitQFastScanScanner::scan_codes requires "
|
|
954
|
+
"HeapResultHandler; custom ResultHandler scan is not supported "
|
|
955
|
+
"by this optimized scanner");
|
|
965
956
|
}
|
|
966
957
|
};
|
|
967
958
|
|
|
@@ -970,8 +961,16 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner {
|
|
|
970
961
|
InvertedListScanner* IndexIVFRaBitQFastScan::get_InvertedListScanner(
|
|
971
962
|
bool store_pairs,
|
|
972
963
|
const IDSelector* sel,
|
|
973
|
-
const IVFSearchParameters*) const {
|
|
974
|
-
|
|
964
|
+
const IVFSearchParameters* search_params_in) const {
|
|
965
|
+
uint8_t used_qb = qb;
|
|
966
|
+
bool used_centered = centered;
|
|
967
|
+
if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>(
|
|
968
|
+
search_params_in)) {
|
|
969
|
+
used_qb = params->qb;
|
|
970
|
+
used_centered = params->centered;
|
|
971
|
+
}
|
|
972
|
+
return new IVFRaBitQFastScanScanner(
|
|
973
|
+
*this, store_pairs, sel, used_qb, used_centered);
|
|
975
974
|
}
|
|
976
975
|
|
|
977
976
|
} // namespace faiss
|