faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -7,10 +7,10 @@
|
|
|
7
7
|
|
|
8
8
|
#include <faiss/IndexRaBitQFastScan.h>
|
|
9
9
|
#include <faiss/impl/CodePackerRaBitQ.h>
|
|
10
|
-
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
11
10
|
#include <faiss/impl/RaBitQUtils.h>
|
|
12
11
|
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
13
|
-
#include <faiss/impl/
|
|
12
|
+
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
|
|
13
|
+
#include <faiss/impl/fast_scan/fast_scan.h>
|
|
14
14
|
#include <faiss/utils/utils.h>
|
|
15
15
|
#include <algorithm>
|
|
16
16
|
#include <cmath>
|
|
@@ -28,13 +28,13 @@ size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
|
|
|
28
28
|
IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
|
|
29
29
|
|
|
30
30
|
IndexRaBitQFastScan::IndexRaBitQFastScan(
|
|
31
|
-
idx_t
|
|
31
|
+
idx_t d_in,
|
|
32
32
|
MetricType metric,
|
|
33
|
-
int
|
|
33
|
+
int bbs_in,
|
|
34
34
|
uint8_t nb_bits)
|
|
35
|
-
: rabitq(
|
|
35
|
+
: rabitq(d_in, metric, nb_bits) {
|
|
36
36
|
// RaBitQ-specific validation
|
|
37
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
37
|
+
FAISS_THROW_IF_NOT_MSG(d_in > 0, "Dimension must be positive");
|
|
38
38
|
FAISS_THROW_IF_NOT_MSG(
|
|
39
39
|
metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
|
|
40
40
|
"RaBitQ FastScan only supports L2 and Inner Product metrics");
|
|
@@ -43,18 +43,19 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(
|
|
|
43
43
|
|
|
44
44
|
// RaBitQ uses 1 bit per dimension packed into 4-bit FastScan sub-quantizers
|
|
45
45
|
// Each FastScan sub-quantizer handles 4 RaBitQ dimensions
|
|
46
|
-
const size_t M_fastscan = (
|
|
46
|
+
const size_t M_fastscan = (d_in + 3) / 4;
|
|
47
47
|
constexpr size_t nbits_fastscan = 4;
|
|
48
48
|
|
|
49
49
|
// init_fastscan will validate bbs % 32 == 0 and nbits_fastscan == 4
|
|
50
|
-
init_fastscan(
|
|
50
|
+
init_fastscan(
|
|
51
|
+
static_cast<int>(d_in), M_fastscan, nbits_fastscan, metric, bbs_in);
|
|
51
52
|
|
|
52
53
|
// Compute code_size directly using RaBitQuantizer
|
|
53
|
-
code_size = rabitq.compute_code_size(
|
|
54
|
+
code_size = rabitq.compute_code_size(d_in, nb_bits);
|
|
54
55
|
|
|
55
56
|
// Set RaBitQ-specific parameters
|
|
56
57
|
qb = 8;
|
|
57
|
-
center.resize(
|
|
58
|
+
center.resize(d_in, 0.0f);
|
|
58
59
|
}
|
|
59
60
|
|
|
60
61
|
CodePacker* IndexRaBitQFastScan::get_CodePacker() const {
|
|
@@ -102,7 +103,7 @@ size_t IndexRaBitQFastScan::remove_ids(const IDSelector& sel) {
|
|
|
102
103
|
return nremove;
|
|
103
104
|
}
|
|
104
105
|
|
|
105
|
-
IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int
|
|
106
|
+
IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs_in)
|
|
106
107
|
: rabitq(orig.rabitq) {
|
|
107
108
|
// RaBitQ-specific validation
|
|
108
109
|
FAISS_THROW_IF_NOT_MSG(orig.d > 0, "Dimension must be positive");
|
|
@@ -122,7 +123,7 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
|
|
|
122
123
|
M_fastscan,
|
|
123
124
|
nbits_fastscan,
|
|
124
125
|
orig.metric_type,
|
|
125
|
-
|
|
126
|
+
bbs_in);
|
|
126
127
|
|
|
127
128
|
code_size = rabitq.compute_code_size(d, rabitq.nb_bits);
|
|
128
129
|
|
|
@@ -148,7 +149,7 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
|
|
|
148
149
|
const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
|
|
149
150
|
uint8_t* fs_code = fastscan_codes.get() + i * code_size;
|
|
150
151
|
|
|
151
|
-
for (size_t j = 0; j < orig.d; j++) {
|
|
152
|
+
for (size_t j = 0; j < static_cast<size_t>(orig.d); j++) {
|
|
152
153
|
const size_t orig_byte_idx = j / 8;
|
|
153
154
|
const size_t orig_bit_offset = j % 8;
|
|
154
155
|
const bool bit_value =
|
|
@@ -197,13 +198,13 @@ void IndexRaBitQFastScan::train(idx_t n, const float* x) {
|
|
|
197
198
|
// compute a centroid
|
|
198
199
|
std::vector<float> centroid(d, 0);
|
|
199
200
|
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
|
|
200
|
-
for (size_t j = 0; j < d; j++) {
|
|
201
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
201
202
|
centroid[j] += x[i * d + j];
|
|
202
203
|
}
|
|
203
204
|
}
|
|
204
205
|
|
|
205
206
|
if (n != 0) {
|
|
206
|
-
for (size_t j = 0; j < d; j++) {
|
|
207
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
207
208
|
centroid[j] /= (float)n;
|
|
208
209
|
}
|
|
209
210
|
}
|
|
@@ -279,9 +280,11 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
|
|
|
279
280
|
ntotal += n;
|
|
280
281
|
}
|
|
281
282
|
|
|
282
|
-
void IndexRaBitQFastScan::compute_codes(
|
|
283
|
-
|
|
284
|
-
|
|
283
|
+
void IndexRaBitQFastScan::compute_codes(
|
|
284
|
+
uint8_t* out_codes,
|
|
285
|
+
idx_t n,
|
|
286
|
+
const float* x) const {
|
|
287
|
+
FAISS_ASSERT(out_codes != nullptr);
|
|
285
288
|
FAISS_ASSERT(x != nullptr);
|
|
286
289
|
FAISS_ASSERT(
|
|
287
290
|
(metric_type == MetricType::METRIC_L2 ||
|
|
@@ -296,23 +299,23 @@ void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
|
296
299
|
const size_t ex_bits = rabitq.nb_bits - 1;
|
|
297
300
|
const size_t ex_code_size = (d * ex_bits + 7) / 8;
|
|
298
301
|
|
|
299
|
-
memset(
|
|
302
|
+
memset(out_codes, 0, n * code_size);
|
|
300
303
|
|
|
301
304
|
#pragma omp parallel for if (n > 1000)
|
|
302
305
|
for (int64_t i = 0; i < n; i++) {
|
|
303
|
-
uint8_t* const code =
|
|
306
|
+
uint8_t* const code = out_codes + i * code_size;
|
|
304
307
|
const float* const x_row = x + i * d;
|
|
305
308
|
|
|
306
309
|
// Compute residual once, reuse for both sign bits and ex-bits
|
|
307
310
|
std::vector<float> residual(d);
|
|
308
|
-
for (size_t j = 0; j < d; j++) {
|
|
311
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
309
312
|
const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
|
|
310
313
|
residual[j] = x_row[j] - centroid_val;
|
|
311
314
|
}
|
|
312
315
|
|
|
313
316
|
// Pack sign bits directly into FastScan format using precomputed
|
|
314
317
|
// residual
|
|
315
|
-
for (size_t j = 0; j < d; j++) {
|
|
318
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
316
319
|
if (residual[j] > 0.0f) {
|
|
317
320
|
rabitq_utils::set_bit_fastscan(code, j);
|
|
318
321
|
}
|
|
@@ -412,7 +415,7 @@ void IndexRaBitQFastScan::compute_float_LUT(
|
|
|
412
415
|
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
413
416
|
const size_t dim_idx = dim_start + dim_offset;
|
|
414
417
|
|
|
415
|
-
if (dim_idx < d) {
|
|
418
|
+
if (dim_idx < static_cast<size_t>(d)) {
|
|
416
419
|
const bool db_bit = (code_val >> dim_offset) & 1;
|
|
417
420
|
const float query_value = rotated_qq[dim_idx];
|
|
418
421
|
|
|
@@ -447,7 +450,8 @@ void IndexRaBitQFastScan::compute_float_LUT(
|
|
|
447
450
|
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
448
451
|
const size_t dim_idx = dim_start + dim_offset;
|
|
449
452
|
|
|
450
|
-
if (dim_idx < d &&
|
|
453
|
+
if (dim_idx < static_cast<size_t>(d) &&
|
|
454
|
+
((code_val >> dim_offset) & 1)) {
|
|
451
455
|
inner_product += rotated_qq[dim_idx];
|
|
452
456
|
popcount++;
|
|
453
457
|
}
|
|
@@ -463,12 +467,16 @@ void IndexRaBitQFastScan::compute_float_LUT(
|
|
|
463
467
|
}
|
|
464
468
|
}
|
|
465
469
|
|
|
470
|
+
size_t IndexRaBitQFastScan::fast_scan_code_size() const {
|
|
471
|
+
return (d + 7) / 8;
|
|
472
|
+
}
|
|
473
|
+
|
|
466
474
|
void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
467
475
|
const {
|
|
468
476
|
const float* centroid_in =
|
|
469
477
|
(center.data() == nullptr) ? nullptr : center.data();
|
|
470
|
-
const uint8_t*
|
|
471
|
-
FAISS_ASSERT(
|
|
478
|
+
const uint8_t* input_codes = bytes;
|
|
479
|
+
FAISS_ASSERT(input_codes != nullptr);
|
|
472
480
|
FAISS_ASSERT(x != nullptr);
|
|
473
481
|
|
|
474
482
|
const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
@@ -477,7 +485,7 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
477
485
|
#pragma omp parallel for if (n > 1000)
|
|
478
486
|
for (int64_t i = 0; i < n; i++) {
|
|
479
487
|
// Access code using correct FastScan format
|
|
480
|
-
const uint8_t* code =
|
|
488
|
+
const uint8_t* code = input_codes + i * code_size;
|
|
481
489
|
|
|
482
490
|
// Extract factors directly from embedded codes
|
|
483
491
|
const uint8_t* factors_ptr = code + bit_pattern_size;
|
|
@@ -485,7 +493,7 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
485
493
|
reinterpret_cast<const rabitq_utils::SignBitFactors*>(
|
|
486
494
|
factors_ptr);
|
|
487
495
|
|
|
488
|
-
for (size_t j = 0; j < d; j++) {
|
|
496
|
+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
|
|
489
497
|
// Use RaBitQUtils for consistent bit extraction
|
|
490
498
|
bool bit_value = rabitq_utils::extract_bit_fastscan(code, j);
|
|
491
499
|
float bit = bit_value ? 1.0f : 0.0f;
|
|
@@ -522,264 +530,20 @@ void IndexRaBitQFastScan::search(
|
|
|
522
530
|
}
|
|
523
531
|
}
|
|
524
532
|
|
|
525
|
-
|
|
526
|
-
template <class C, bool with_id_map>
|
|
527
|
-
RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
|
|
528
|
-
const IndexRaBitQFastScan* index,
|
|
529
|
-
size_t nq_val,
|
|
530
|
-
size_t k_val,
|
|
531
|
-
float* distances,
|
|
532
|
-
int64_t* labels,
|
|
533
|
-
const IDSelector* sel_in,
|
|
534
|
-
const FastScanDistancePostProcessing& ctx,
|
|
535
|
-
bool multi_bit)
|
|
536
|
-
: RHC(nq_val, index->ntotal, sel_in),
|
|
537
|
-
rabitq_index(index),
|
|
538
|
-
heap_distances(distances),
|
|
539
|
-
heap_labels(labels),
|
|
540
|
-
nq(nq_val),
|
|
541
|
-
k(k_val),
|
|
542
|
-
context(ctx),
|
|
543
|
-
is_multi_bit(multi_bit),
|
|
544
|
-
storage_size(index->compute_per_vector_storage_size()),
|
|
545
|
-
packed_block_size(((index->M2 + 1) / 2) * index->bbs),
|
|
546
|
-
full_block_size(index->get_block_stride()),
|
|
547
|
-
packer(index->get_CodePacker()) {
|
|
548
|
-
// Initialize heaps for all queries in constructor
|
|
549
|
-
// This allows us to support direct normalizer assignment
|
|
550
|
-
#pragma omp parallel for if (nq > 100)
|
|
551
|
-
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
552
|
-
float* heap_dis = heap_distances + q * k;
|
|
553
|
-
int64_t* heap_ids = heap_labels + q * k;
|
|
554
|
-
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
|
|
555
|
-
}
|
|
556
|
-
}
|
|
557
|
-
|
|
558
|
-
template <class C, bool with_id_map>
|
|
559
|
-
void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
560
|
-
size_t q,
|
|
561
|
-
size_t b,
|
|
562
|
-
simd16uint16 d0,
|
|
563
|
-
simd16uint16 d1) {
|
|
564
|
-
ALIGNED(32) uint16_t d32tab[32];
|
|
565
|
-
d0.store(d32tab);
|
|
566
|
-
d1.store(d32tab + 16);
|
|
567
|
-
|
|
568
|
-
// Get heap pointers and query factors (computed once per batch)
|
|
569
|
-
float* const heap_dis = heap_distances + q * k;
|
|
570
|
-
int64_t* const heap_ids = heap_labels + q * k;
|
|
571
|
-
|
|
572
|
-
// Access query factors from query_factors pointer
|
|
573
|
-
rabitq_utils::QueryFactorsData query_factors_data = {};
|
|
574
|
-
if (context.query_factors != nullptr) {
|
|
575
|
-
query_factors_data = context.query_factors[q];
|
|
576
|
-
}
|
|
577
|
-
|
|
578
|
-
// Compute normalizers once per batch
|
|
579
|
-
const float one_a = normalizers ? (1.0f / normalizers[2 * q]) : 1.0f;
|
|
580
|
-
const float bias = normalizers ? normalizers[2 * q + 1] : 0.0f;
|
|
581
|
-
|
|
582
|
-
// Compute loop bounds to avoid redundant bounds checking
|
|
583
|
-
const size_t base_db_idx = this->j0 + b * 32;
|
|
584
|
-
const size_t max_vectors = (base_db_idx < rabitq_index->ntotal)
|
|
585
|
-
? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
|
|
586
|
-
: 0;
|
|
587
|
-
|
|
588
|
-
// Compute block auxiliary region base pointer once per batch.
|
|
589
|
-
// Since bbs=32, each batch of 32 vectors aligns to one block.
|
|
590
|
-
const size_t block_idx = base_db_idx / rabitq_index->bbs;
|
|
591
|
-
const uint8_t* aux_base = rabitq_index->codes.get() +
|
|
592
|
-
block_idx * full_block_size + packed_block_size;
|
|
593
|
-
|
|
594
|
-
// Stats tracking for multi-bit two-stage search only
|
|
595
|
-
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
596
|
-
// n_multibit_evaluations: candidates requiring full multi-bit distance
|
|
597
|
-
size_t local_1bit_evaluations = 0;
|
|
598
|
-
size_t local_multibit_evaluations = 0;
|
|
599
|
-
|
|
600
|
-
// Process distances in batch
|
|
601
|
-
for (size_t i = 0; i < max_vectors; i++) {
|
|
602
|
-
const size_t db_idx = base_db_idx + i;
|
|
603
|
-
|
|
604
|
-
// Normalize distance from LUT lookup
|
|
605
|
-
const float normalized_distance = d32tab[i] * one_a + bias;
|
|
606
|
-
|
|
607
|
-
// Access factors from block auxiliary region
|
|
608
|
-
const uint8_t* base_ptr = aux_base + i * storage_size;
|
|
609
|
-
|
|
610
|
-
if (is_multi_bit) {
|
|
611
|
-
// Track candidates actually considered for two-stage filtering
|
|
612
|
-
local_1bit_evaluations++;
|
|
613
|
-
|
|
614
|
-
const SignBitFactorsWithError& full_factors =
|
|
615
|
-
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
616
|
-
|
|
617
|
-
float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
|
|
618
|
-
normalized_distance,
|
|
619
|
-
full_factors,
|
|
620
|
-
query_factors_data,
|
|
621
|
-
rabitq_index->centered,
|
|
622
|
-
rabitq_index->qb,
|
|
623
|
-
rabitq_index->d);
|
|
624
|
-
|
|
625
|
-
// Adaptive filtering: decide whether to compute full distance
|
|
626
|
-
const bool is_similarity = rabitq_index->metric_type ==
|
|
627
|
-
MetricType::METRIC_INNER_PRODUCT;
|
|
628
|
-
bool should_refine = rabitq_utils::should_refine_candidate(
|
|
629
|
-
dist_1bit,
|
|
630
|
-
full_factors.f_error,
|
|
631
|
-
context.query_factors ? context.query_factors[q].g_error
|
|
632
|
-
: 0.0f,
|
|
633
|
-
heap_dis[0],
|
|
634
|
-
is_similarity);
|
|
635
|
-
|
|
636
|
-
if (should_refine) {
|
|
637
|
-
local_multibit_evaluations++;
|
|
638
|
-
float dist_full = compute_full_multibit_distance(db_idx, q);
|
|
639
|
-
|
|
640
|
-
if (Cfloat::cmp(heap_dis[0], dist_full)) {
|
|
641
|
-
heap_replace_top<Cfloat>(
|
|
642
|
-
k, heap_dis, heap_ids, dist_full, db_idx);
|
|
643
|
-
}
|
|
644
|
-
}
|
|
645
|
-
} else {
|
|
646
|
-
const rabitq_utils::SignBitFactors& db_factors =
|
|
647
|
-
*reinterpret_cast<const rabitq_utils::SignBitFactors*>(
|
|
648
|
-
base_ptr);
|
|
649
|
-
|
|
650
|
-
float adjusted_distance =
|
|
651
|
-
rabitq_utils::compute_1bit_adjusted_distance(
|
|
652
|
-
normalized_distance,
|
|
653
|
-
db_factors,
|
|
654
|
-
query_factors_data,
|
|
655
|
-
rabitq_index->centered,
|
|
656
|
-
rabitq_index->qb,
|
|
657
|
-
rabitq_index->d);
|
|
658
|
-
|
|
659
|
-
// Add to heap if better than current worst
|
|
660
|
-
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
661
|
-
heap_replace_top<Cfloat>(
|
|
662
|
-
k, heap_dis, heap_ids, adjusted_distance, db_idx);
|
|
663
|
-
}
|
|
664
|
-
}
|
|
665
|
-
}
|
|
666
|
-
|
|
667
|
-
// Update global stats atomically
|
|
668
|
-
#pragma omp atomic
|
|
669
|
-
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
|
|
670
|
-
#pragma omp atomic
|
|
671
|
-
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
|
|
672
|
-
}
|
|
673
|
-
|
|
674
|
-
template <class C, bool with_id_map>
|
|
675
|
-
void RaBitQHeapHandler<C, with_id_map>::begin(const float* norms) {
|
|
676
|
-
normalizers = norms;
|
|
677
|
-
// Heap initialization is now done in constructor
|
|
678
|
-
}
|
|
679
|
-
|
|
680
|
-
template <class C, bool with_id_map>
|
|
681
|
-
void RaBitQHeapHandler<C, with_id_map>::end() {
|
|
682
|
-
// Reorder final results
|
|
683
|
-
#pragma omp parallel for if (nq > 100)
|
|
684
|
-
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
685
|
-
float* heap_dis = heap_distances + q * k;
|
|
686
|
-
int64_t* heap_ids = heap_labels + q * k;
|
|
687
|
-
heap_reorder<Cfloat>(k, heap_dis, heap_ids);
|
|
688
|
-
}
|
|
689
|
-
}
|
|
690
|
-
|
|
691
|
-
template <class C, bool with_id_map>
|
|
692
|
-
float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
|
|
693
|
-
float dist_1bit,
|
|
694
|
-
size_t db_idx,
|
|
695
|
-
size_t q) const {
|
|
696
|
-
// Access f_error from block auxiliary region
|
|
697
|
-
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
698
|
-
rabitq_index->codes.get(),
|
|
699
|
-
db_idx,
|
|
700
|
-
rabitq_index->bbs,
|
|
701
|
-
packed_block_size,
|
|
702
|
-
full_block_size,
|
|
703
|
-
storage_size);
|
|
704
|
-
const SignBitFactorsWithError& db_factors =
|
|
705
|
-
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
706
|
-
float f_error = db_factors.f_error;
|
|
707
|
-
|
|
708
|
-
// Get g_error from query factors (query-dependent error term)
|
|
709
|
-
float g_error = 0.0f;
|
|
710
|
-
if (context.query_factors != nullptr) {
|
|
711
|
-
g_error = context.query_factors[q].g_error;
|
|
712
|
-
}
|
|
713
|
-
|
|
714
|
-
// Compute error adjustment: f_error * g_error
|
|
715
|
-
float error_adjustment = f_error * g_error;
|
|
716
|
-
|
|
717
|
-
return dist_1bit - error_adjustment;
|
|
718
|
-
}
|
|
719
|
-
|
|
720
|
-
template <class C, bool with_id_map>
|
|
721
|
-
float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
|
|
722
|
-
size_t db_idx,
|
|
723
|
-
size_t q) const {
|
|
724
|
-
const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
|
|
725
|
-
const size_t dim = rabitq_index->d;
|
|
726
|
-
|
|
727
|
-
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
728
|
-
rabitq_index->codes.get(),
|
|
729
|
-
db_idx,
|
|
730
|
-
rabitq_index->bbs,
|
|
731
|
-
packed_block_size,
|
|
732
|
-
full_block_size,
|
|
733
|
-
storage_size);
|
|
734
|
-
|
|
735
|
-
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
736
|
-
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
737
|
-
const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
|
|
738
|
-
base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
|
|
739
|
-
|
|
740
|
-
// Get query factors reference (avoid copying)
|
|
741
|
-
const rabitq_utils::QueryFactorsData& query_factors =
|
|
742
|
-
context.query_factors[q];
|
|
743
|
-
|
|
744
|
-
// Get sign bits from FastScan packed format
|
|
745
|
-
std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
|
|
746
|
-
packer->unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
|
|
747
|
-
const uint8_t* sign_bits = unpacked_code.data();
|
|
748
|
-
|
|
749
|
-
return rabitq_utils::compute_full_multibit_distance(
|
|
750
|
-
sign_bits,
|
|
751
|
-
ex_code,
|
|
752
|
-
ex_fac,
|
|
753
|
-
query_factors.rotated_q.data(),
|
|
754
|
-
(rabitq_index->metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
755
|
-
? query_factors.q_dot_c
|
|
756
|
-
: query_factors.qr_to_c_L2sqr,
|
|
757
|
-
dim,
|
|
758
|
-
ex_bits,
|
|
759
|
-
rabitq_index->metric_type);
|
|
760
|
-
}
|
|
533
|
+
std::unique_ptr<FastScanCodeScanner> IndexRaBitQFastScan::make_knn_scanner(
|
|
761
534
|
|
|
762
|
-
// Implementation of virtual make_knn_handler method
|
|
763
|
-
SIMDResultHandlerToFloat* IndexRaBitQFastScan::make_knn_handler(
|
|
764
535
|
bool is_max,
|
|
765
|
-
int /*impl*/,
|
|
766
536
|
idx_t n,
|
|
767
537
|
idx_t k,
|
|
768
538
|
size_t /*ntotal*/,
|
|
769
539
|
float* distances,
|
|
770
540
|
idx_t* labels,
|
|
771
541
|
const IDSelector* sel,
|
|
542
|
+
int /*impl*/,
|
|
772
543
|
const FastScanDistancePostProcessing& context) const {
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
if (is_max) {
|
|
777
|
-
return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
|
|
778
|
-
this, n, k, distances, labels, sel, context, multi_bit);
|
|
779
|
-
} else {
|
|
780
|
-
return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
|
|
781
|
-
this, n, k, distances, labels, sel, context, multi_bit);
|
|
782
|
-
}
|
|
544
|
+
const bool is_multi_bit = rabitq.nb_bits > 1;
|
|
545
|
+
return rabitq_make_knn_scanner(
|
|
546
|
+
this, is_max, n, k, distances, labels, sel, context, is_multi_bit);
|
|
783
547
|
}
|
|
784
548
|
|
|
785
549
|
} // namespace faiss
|