faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -10,6 +10,7 @@
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
11
|
#include <faiss/impl/RaBitQUtils.h>
|
|
12
12
|
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
13
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
13
14
|
#include <faiss/utils/distances.h>
|
|
14
15
|
#include <faiss/utils/rabitq_simd.h>
|
|
15
16
|
|
|
@@ -27,10 +28,13 @@ using rabitq_utils::QueryFactorsData;
|
|
|
27
28
|
using rabitq_utils::SignBitFactors;
|
|
28
29
|
using rabitq_utils::SignBitFactorsWithError;
|
|
29
30
|
|
|
30
|
-
RaBitQuantizer::RaBitQuantizer(
|
|
31
|
-
|
|
31
|
+
RaBitQuantizer::RaBitQuantizer(
|
|
32
|
+
size_t d_in,
|
|
33
|
+
MetricType metric,
|
|
34
|
+
size_t nb_bits_in)
|
|
35
|
+
: Quantizer(d_in, 0), // code_size will be set below
|
|
32
36
|
metric_type{metric},
|
|
33
|
-
nb_bits{
|
|
37
|
+
nb_bits{nb_bits_in} {
|
|
34
38
|
// Validate nb_bits range
|
|
35
39
|
FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
|
|
36
40
|
|
|
@@ -38,7 +42,7 @@ RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
|
|
|
38
42
|
code_size = compute_code_size(d, nb_bits);
|
|
39
43
|
}
|
|
40
44
|
|
|
41
|
-
size_t RaBitQuantizer::compute_code_size(size_t
|
|
45
|
+
size_t RaBitQuantizer::compute_code_size(size_t d_in, size_t num_bits) const {
|
|
42
46
|
// Validate inputs
|
|
43
47
|
FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
|
|
44
48
|
|
|
@@ -50,7 +54,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
|
|
|
50
54
|
// Layout for multi-bit: [binary_code: (d+7)/8
|
|
51
55
|
// bytes][SignBitFactorsWithError: 12 bytes]
|
|
52
56
|
// factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
|
|
53
|
-
size_t base_size = (
|
|
57
|
+
size_t base_size = (d_in + 7) / 8 +
|
|
54
58
|
(ex_bits == 0 ? sizeof(SignBitFactors)
|
|
55
59
|
: sizeof(SignBitFactorsWithError));
|
|
56
60
|
|
|
@@ -58,7 +62,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
|
|
|
58
62
|
// Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
|
|
59
63
|
size_t ex_size = 0;
|
|
60
64
|
if (ex_bits > 0) {
|
|
61
|
-
ex_size = (
|
|
65
|
+
ex_size = (d_in * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
|
|
62
66
|
}
|
|
63
67
|
|
|
64
68
|
return base_size + ex_size;
|
|
@@ -92,7 +96,7 @@ void RaBitQuantizer::compute_codes_core(
|
|
|
92
96
|
|
|
93
97
|
// Compute codes
|
|
94
98
|
#pragma omp parallel for if (n > 1000)
|
|
95
|
-
for (int64_t i = 0; i < n; i++) {
|
|
99
|
+
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
|
|
96
100
|
// Pointer to this vector's code
|
|
97
101
|
uint8_t* code = codes + i * code_size;
|
|
98
102
|
|
|
@@ -186,7 +190,7 @@ void RaBitQuantizer::decode_core(
|
|
|
186
190
|
const size_t ex_bits = nb_bits - 1;
|
|
187
191
|
|
|
188
192
|
#pragma omp parallel for if (n > 1000)
|
|
189
|
-
for (int64_t i = 0; i < n; i++) {
|
|
193
|
+
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
|
|
190
194
|
const uint8_t* code = codes + i * code_size;
|
|
191
195
|
|
|
192
196
|
// split the code into parts
|
|
@@ -218,162 +222,159 @@ void RaBitQuantizer::decode_core(
|
|
|
218
222
|
|
|
219
223
|
namespace {
|
|
220
224
|
|
|
225
|
+
// Distance computers templatized on SIMDLevel to avoid per-call dynamic
|
|
226
|
+
// dispatch. The SIMDLevel is baked in at construction time via
|
|
227
|
+
// get_distance_computer, so virtual calls through the base class go
|
|
228
|
+
// directly to the SIMD-specialized code.
|
|
229
|
+
|
|
230
|
+
template <SIMDLevel SL>
|
|
221
231
|
struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
|
|
222
232
|
// the rotated query (qr - c)
|
|
223
233
|
std::vector<float> rotated_q;
|
|
224
234
|
// some additional numbers for the query
|
|
225
235
|
QueryFactorsData query_fac;
|
|
226
236
|
|
|
227
|
-
RaBitQDistanceComputerNotQ();
|
|
237
|
+
RaBitQDistanceComputerNotQ() = default;
|
|
228
238
|
|
|
229
239
|
// Compute distance using only 1-bit codes (fast)
|
|
230
|
-
float distance_to_code_1bit(const uint8_t* code) override
|
|
240
|
+
float distance_to_code_1bit(const uint8_t* code) override {
|
|
241
|
+
FAISS_ASSERT(code != nullptr);
|
|
242
|
+
FAISS_ASSERT(
|
|
243
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
244
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
245
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
231
246
|
|
|
232
|
-
|
|
233
|
-
|
|
247
|
+
// split the code into parts
|
|
248
|
+
const uint8_t* binary_data = code;
|
|
234
249
|
|
|
235
|
-
|
|
236
|
-
|
|
250
|
+
// Cast to appropriate type based on nb_bits
|
|
251
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
252
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
253
|
+
// f_error
|
|
254
|
+
size_t ex_bits = nb_bits - 1;
|
|
255
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
256
|
+
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
257
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
258
|
+
code + (d + 7) / 8);
|
|
259
|
+
|
|
260
|
+
// this is the baseline code
|
|
261
|
+
//
|
|
262
|
+
// compute <q,o> using floats
|
|
263
|
+
float dot_qo = 0;
|
|
264
|
+
// It was a deliberate decision (after discussion) not to pre-cache
|
|
265
|
+
// the sum of all bits, just in order to reduce the overhead per
|
|
266
|
+
// vector.
|
|
267
|
+
uint64_t sum_q = 0;
|
|
268
|
+
|
|
269
|
+
for (size_t i = 0; i < d; i++) {
|
|
270
|
+
// Extract i-th bit
|
|
271
|
+
bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
|
|
272
|
+
// accumulate dp
|
|
273
|
+
dot_qo += bit ? rotated_q[i] : 0;
|
|
274
|
+
// accumulate sum-of-bits
|
|
275
|
+
sum_q += bit ? 1 : 0;
|
|
276
|
+
}
|
|
237
277
|
|
|
238
|
-
|
|
278
|
+
// Apply query factors
|
|
279
|
+
float final_dot =
|
|
280
|
+
query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
|
|
239
281
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
245
|
-
FAISS_ASSERT(rotated_q.size() == d);
|
|
246
|
-
|
|
247
|
-
// split the code into parts
|
|
248
|
-
const uint8_t* binary_data = code;
|
|
249
|
-
|
|
250
|
-
// Cast to appropriate type based on nb_bits
|
|
251
|
-
// For 1-bit: use SignBitFactors (8 bytes)
|
|
252
|
-
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
253
|
-
// f_error
|
|
254
|
-
size_t ex_bits = nb_bits - 1;
|
|
255
|
-
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
256
|
-
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
257
|
-
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
258
|
-
code + (d + 7) / 8);
|
|
259
|
-
|
|
260
|
-
// this is the baseline code
|
|
261
|
-
//
|
|
262
|
-
// compute <q,o> using floats
|
|
263
|
-
float dot_qo = 0;
|
|
264
|
-
// It was a deliberate decision (after discussion) not to pre-cache
|
|
265
|
-
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
266
|
-
uint64_t sum_q = 0;
|
|
267
|
-
|
|
268
|
-
for (size_t i = 0; i < d; i++) {
|
|
269
|
-
// Extract i-th bit
|
|
270
|
-
bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
|
|
271
|
-
// accumulate dp
|
|
272
|
-
dot_qo += bit ? rotated_q[i] : 0;
|
|
273
|
-
// accumulate sum-of-bits
|
|
274
|
-
sum_q += bit ? 1 : 0;
|
|
275
|
-
}
|
|
282
|
+
// pre_dist = ||or - c||^2 + ||qr - c||^2 -
|
|
283
|
+
// 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
|
|
284
|
+
float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
|
|
285
|
+
2 * base_fac->dp_multiplier * final_dot;
|
|
276
286
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
if (metric_type == MetricType::METRIC_L2) {
|
|
287
|
-
// ||or - q||^ 2
|
|
288
|
-
return pre_dist;
|
|
289
|
-
} else {
|
|
290
|
-
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
291
|
-
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
287
|
+
if (metric_type == MetricType::METRIC_L2) {
|
|
288
|
+
// ||or - q||^2
|
|
289
|
+
return pre_dist;
|
|
290
|
+
} else {
|
|
291
|
+
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
292
|
+
// -2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
293
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
294
|
+
}
|
|
292
295
|
}
|
|
293
|
-
}
|
|
294
296
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
297
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
298
|
+
float distance_to_code_full(const uint8_t* code) override {
|
|
299
|
+
FAISS_ASSERT(code != nullptr);
|
|
300
|
+
FAISS_ASSERT(
|
|
301
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
302
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
303
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
301
304
|
|
|
302
|
-
|
|
305
|
+
size_t ex_bits = nb_bits - 1;
|
|
303
306
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
+
if (ex_bits == 0) {
|
|
308
|
+
// No ex-bits, just return 1-bit distance
|
|
309
|
+
return distance_to_code_1bit(code);
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
// Extract pointers to code sections
|
|
313
|
+
const uint8_t* binary_data = code;
|
|
314
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
315
|
+
const uint8_t* ex_code = code + offset;
|
|
316
|
+
const ExtraBitsFactors* ex_fac =
|
|
317
|
+
reinterpret_cast<const ExtraBitsFactors*>(
|
|
318
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
319
|
+
|
|
320
|
+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
321
|
+
? query_fac.q_dot_c
|
|
322
|
+
: query_fac.qr_to_c_L2sqr;
|
|
323
|
+
return rabitq_utils::compute_full_multibit_distance<SL>(
|
|
324
|
+
binary_data,
|
|
325
|
+
ex_code,
|
|
326
|
+
*ex_fac,
|
|
327
|
+
rotated_q.data(),
|
|
328
|
+
qr_base,
|
|
329
|
+
d,
|
|
330
|
+
ex_bits,
|
|
331
|
+
metric_type);
|
|
307
332
|
}
|
|
308
333
|
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
// Call shared utility directly with rotated_q pointer
|
|
317
|
-
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
318
|
-
? query_fac.q_dot_c
|
|
319
|
-
: query_fac.qr_to_c_L2sqr;
|
|
320
|
-
return rabitq_utils::compute_full_multibit_distance(
|
|
321
|
-
binary_data,
|
|
322
|
-
ex_code,
|
|
323
|
-
*ex_fac,
|
|
324
|
-
rotated_q.data(),
|
|
325
|
-
qr_base,
|
|
326
|
-
d,
|
|
327
|
-
ex_bits,
|
|
328
|
-
metric_type);
|
|
329
|
-
}
|
|
334
|
+
void set_query(const float* x) override {
|
|
335
|
+
q = x;
|
|
336
|
+
FAISS_ASSERT(x != nullptr);
|
|
337
|
+
FAISS_ASSERT(
|
|
338
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
339
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
330
340
|
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
// compute the distance from the query to the centroid
|
|
339
|
-
if (centroid != nullptr) {
|
|
340
|
-
query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
|
|
341
|
-
} else {
|
|
342
|
-
query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
|
|
343
|
-
}
|
|
341
|
+
// compute the distance from the query to the centroid
|
|
342
|
+
if (centroid != nullptr) {
|
|
343
|
+
query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
|
|
344
|
+
} else {
|
|
345
|
+
query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
|
|
346
|
+
}
|
|
344
347
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
348
|
+
// subtract c, obtain P^(-1)(qr - c)
|
|
349
|
+
rotated_q.resize(d);
|
|
350
|
+
for (size_t i = 0; i < d; i++) {
|
|
351
|
+
rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
|
|
352
|
+
}
|
|
350
353
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
354
|
+
// Compute g_error = ||qr - c|| (L2 norm of rotated query)
|
|
355
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
354
356
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
+
// compute some numbers — do not quantize the query
|
|
358
|
+
const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
357
359
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
}
|
|
360
|
+
float sum_q = 0;
|
|
361
|
+
for (size_t i = 0; i < d; i++) {
|
|
362
|
+
sum_q += rotated_q[i];
|
|
363
|
+
}
|
|
363
364
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
365
|
+
query_fac.c1 = 2 * inv_d;
|
|
366
|
+
query_fac.c2 = 0;
|
|
367
|
+
query_fac.c34 = sum_q * inv_d;
|
|
367
368
|
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
369
|
+
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
|
|
370
|
+
query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
|
|
371
|
+
query_fac.q_dot_c =
|
|
372
|
+
centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
|
|
373
|
+
}
|
|
373
374
|
}
|
|
374
|
-
}
|
|
375
|
+
};
|
|
375
376
|
|
|
376
|
-
|
|
377
|
+
template <SIMDLevel SL>
|
|
377
378
|
struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
|
|
378
379
|
// the rotated and quantized query (qr - c)
|
|
379
380
|
std::vector<float> rotated_q;
|
|
@@ -391,176 +392,188 @@ struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
|
|
|
391
392
|
// the smallest value divisible by 8 that is not smaller than dim
|
|
392
393
|
size_t popcount_aligned_dim = 0;
|
|
393
394
|
|
|
394
|
-
RaBitQDistanceComputerQ();
|
|
395
|
+
RaBitQDistanceComputerQ() = default;
|
|
395
396
|
|
|
396
397
|
// Compute distance using only 1-bit codes (fast)
|
|
397
|
-
float distance_to_code_1bit(const uint8_t* code) override
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
398
|
+
float distance_to_code_1bit(const uint8_t* code) override {
|
|
399
|
+
FAISS_ASSERT(code != nullptr);
|
|
400
|
+
FAISS_ASSERT(
|
|
401
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
402
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
401
403
|
|
|
402
|
-
|
|
403
|
-
|
|
404
|
+
// split the code into parts
|
|
405
|
+
size_t size = (d + 7) / 8;
|
|
406
|
+
const uint8_t* binary_data = code;
|
|
404
407
|
|
|
405
|
-
|
|
408
|
+
// Cast to appropriate type based on nb_bits
|
|
409
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
410
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which
|
|
411
|
+
// includes f_error
|
|
412
|
+
size_t ex_bits = nb_bits - 1;
|
|
413
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
414
|
+
? reinterpret_cast<const SignBitFactors*>(code + size)
|
|
415
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
416
|
+
|
|
417
|
+
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
418
|
+
float final_dot = 0;
|
|
419
|
+
if (centered) {
|
|
420
|
+
int64_t int_dot = ((1 << qb) - 1) * d;
|
|
421
|
+
// See RaBitDistanceComputerNotQ::distance_to_code() for
|
|
422
|
+
// baseline code.
|
|
423
|
+
int_dot -= 2 *
|
|
424
|
+
rabitq::bitwise_xor_dot_product<SL>(
|
|
425
|
+
rearranged_rotated_qq.data(),
|
|
426
|
+
binary_data,
|
|
427
|
+
size,
|
|
428
|
+
qb);
|
|
429
|
+
final_dot += int_dot * query_fac.int_dot_scale;
|
|
430
|
+
} else {
|
|
431
|
+
auto dot_qo = rabitq::bitwise_and_dot_product<SL>(
|
|
432
|
+
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
433
|
+
// It was a willful decision (after the discussion) to not to
|
|
434
|
+
// pre-cache the sum of all bits, just in order to reduce the
|
|
435
|
+
// overhead per vector.
|
|
436
|
+
// process 64-bit popcounts
|
|
437
|
+
auto sum_q = rabitq::popcount<SL>(binary_data, size);
|
|
438
|
+
// dot-product itself
|
|
439
|
+
final_dot += query_fac.c1 * dot_qo;
|
|
440
|
+
// normalizer coefficients
|
|
441
|
+
final_dot += query_fac.c2 * sum_q;
|
|
442
|
+
// normalizer coefficients
|
|
443
|
+
final_dot -= query_fac.c34;
|
|
444
|
+
}
|
|
406
445
|
|
|
407
|
-
float
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
(metric_type == MetricType::METRIC_L2 ||
|
|
411
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
446
|
+
const float pre_dist = base_fac->or_minus_c_l2sqr +
|
|
447
|
+
query_fac.qr_to_c_L2sqr -
|
|
448
|
+
2 * base_fac->dp_multiplier * final_dot;
|
|
412
449
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
size_t ex_bits = nb_bits - 1;
|
|
422
|
-
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
423
|
-
? reinterpret_cast<const SignBitFactors*>(code + size)
|
|
424
|
-
: reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
425
|
-
|
|
426
|
-
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
427
|
-
float final_dot = 0;
|
|
428
|
-
if (centered) {
|
|
429
|
-
int64_t int_dot = ((1 << qb) - 1) * d;
|
|
430
|
-
// See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
|
|
431
|
-
int_dot -= 2 *
|
|
432
|
-
rabitq::bitwise_xor_dot_product(
|
|
433
|
-
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
434
|
-
final_dot += int_dot * query_fac.int_dot_scale;
|
|
435
|
-
} else {
|
|
436
|
-
auto dot_qo = rabitq::bitwise_and_dot_product(
|
|
437
|
-
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
438
|
-
// It was a willful decision (after the discussion) to not to pre-cache
|
|
439
|
-
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
440
|
-
// process 64-bit popcounts
|
|
441
|
-
auto sum_q = rabitq::popcount(binary_data, size);
|
|
442
|
-
// dot-product itself
|
|
443
|
-
final_dot += query_fac.c1 * dot_qo;
|
|
444
|
-
// normalizer coefficients
|
|
445
|
-
final_dot += query_fac.c2 * sum_q;
|
|
446
|
-
// normalizer coefficients
|
|
447
|
-
final_dot -= query_fac.c34;
|
|
450
|
+
if (metric_type == MetricType::METRIC_L2) {
|
|
451
|
+
// ||or - q||^ 2
|
|
452
|
+
return pre_dist;
|
|
453
|
+
} else {
|
|
454
|
+
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
455
|
+
// 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
456
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
457
|
+
}
|
|
448
458
|
}
|
|
449
459
|
|
|
450
|
-
//
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
return pre_dist;
|
|
458
|
-
} else {
|
|
459
|
-
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
460
|
-
// 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
461
|
-
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
462
|
-
}
|
|
463
|
-
}
|
|
460
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
461
|
+
float distance_to_code_full(const uint8_t* code) override {
|
|
462
|
+
FAISS_ASSERT(code != nullptr);
|
|
463
|
+
FAISS_ASSERT(
|
|
464
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
465
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
466
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
464
467
|
|
|
465
|
-
|
|
466
|
-
FAISS_ASSERT(code != nullptr);
|
|
467
|
-
FAISS_ASSERT(
|
|
468
|
-
(metric_type == MetricType::METRIC_L2 ||
|
|
469
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
470
|
-
FAISS_ASSERT(rotated_q.size() == d);
|
|
468
|
+
size_t ex_bits = nb_bits - 1;
|
|
471
469
|
|
|
472
|
-
|
|
470
|
+
if (ex_bits == 0) {
|
|
471
|
+
// No ex-bits, just return 1-bit distance
|
|
472
|
+
return distance_to_code_1bit(code);
|
|
473
|
+
}
|
|
473
474
|
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
475
|
+
// Extract pointers to code sections
|
|
476
|
+
const uint8_t* binary_data = code;
|
|
477
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
478
|
+
const uint8_t* ex_code = code + offset;
|
|
479
|
+
const ExtraBitsFactors* ex_fac =
|
|
480
|
+
reinterpret_cast<const ExtraBitsFactors*>(
|
|
481
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
482
|
+
|
|
483
|
+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
484
|
+
? query_fac.q_dot_c
|
|
485
|
+
: query_fac.qr_to_c_L2sqr;
|
|
486
|
+
return rabitq_utils::compute_full_multibit_distance<SL>(
|
|
487
|
+
binary_data,
|
|
488
|
+
ex_code,
|
|
489
|
+
*ex_fac,
|
|
490
|
+
rotated_q.data(),
|
|
491
|
+
qr_base,
|
|
492
|
+
d,
|
|
493
|
+
ex_bits,
|
|
494
|
+
metric_type);
|
|
477
495
|
}
|
|
478
496
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
497
|
+
void set_query(const float* x) override {
|
|
498
|
+
q = x;
|
|
499
|
+
FAISS_ASSERT(x != nullptr);
|
|
500
|
+
FAISS_ASSERT(
|
|
501
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
502
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
503
|
+
FAISS_THROW_IF_NOT(qb <= 8);
|
|
504
|
+
FAISS_THROW_IF_NOT(qb > 0);
|
|
505
|
+
|
|
506
|
+
// Use shared utilities for core query factor computation
|
|
507
|
+
// rotated_q is populated directly by compute_query_factors as an
|
|
508
|
+
// output parameter
|
|
509
|
+
query_fac = rabitq_utils::compute_query_factors(
|
|
510
|
+
x,
|
|
511
|
+
d,
|
|
512
|
+
centroid,
|
|
513
|
+
qb,
|
|
514
|
+
centered,
|
|
515
|
+
metric_type,
|
|
516
|
+
rotated_q,
|
|
517
|
+
rotated_qq);
|
|
518
|
+
|
|
519
|
+
// Compute g_error (query norm for lower bound computation)
|
|
520
|
+
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
521
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
522
|
+
|
|
523
|
+
// Rearrange the query vector for SIMD operations
|
|
524
|
+
// (RaBitQuantizer-specific)
|
|
525
|
+
popcount_aligned_dim = ((d + 7) / 8) * 8;
|
|
526
|
+
size_t offset = (d + 7) / 8;
|
|
527
|
+
|
|
528
|
+
rearranged_rotated_qq.resize(offset * qb);
|
|
529
|
+
std::fill(
|
|
530
|
+
rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
|
|
531
|
+
|
|
532
|
+
for (size_t idim = 0; idim < d; idim++) {
|
|
533
|
+
for (size_t iv = 0; iv < qb; iv++) {
|
|
534
|
+
const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
|
|
535
|
+
rearranged_rotated_qq[iv * offset + idim / 8] |=
|
|
536
|
+
bit ? (1 << (idim % 8)) : 0;
|
|
537
|
+
}
|
|
538
|
+
}
|
|
539
|
+
}
|
|
540
|
+
};
|
|
500
541
|
|
|
501
542
|
// Use shared constant from RaBitQUtils
|
|
502
543
|
using rabitq_utils::Z_MAX_BY_QB;
|
|
503
544
|
|
|
504
|
-
void RaBitQDistanceComputerQ::set_query(const float* x) {
|
|
505
|
-
q = x;
|
|
506
|
-
FAISS_ASSERT(x != nullptr);
|
|
507
|
-
FAISS_ASSERT(
|
|
508
|
-
(metric_type == MetricType::METRIC_L2 ||
|
|
509
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
510
|
-
FAISS_THROW_IF_NOT(qb <= 8);
|
|
511
|
-
FAISS_THROW_IF_NOT(qb > 0);
|
|
512
|
-
|
|
513
|
-
// Use shared utilities for core query factor computation
|
|
514
|
-
// rotated_q is populated directly by compute_query_factors as an output
|
|
515
|
-
// parameter
|
|
516
|
-
query_fac = rabitq_utils::compute_query_factors(
|
|
517
|
-
x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
|
|
518
|
-
|
|
519
|
-
// Compute g_error (query norm for lower bound computation)
|
|
520
|
-
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
521
|
-
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
522
|
-
|
|
523
|
-
// Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
|
|
524
|
-
popcount_aligned_dim = ((d + 7) / 8) * 8;
|
|
525
|
-
size_t offset = (d + 7) / 8;
|
|
526
|
-
|
|
527
|
-
rearranged_rotated_qq.resize(offset * qb);
|
|
528
|
-
std::fill(rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
|
|
529
|
-
|
|
530
|
-
for (size_t idim = 0; idim < d; idim++) {
|
|
531
|
-
for (size_t iv = 0; iv < qb; iv++) {
|
|
532
|
-
const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
|
|
533
|
-
rearranged_rotated_qq[iv * offset + idim / 8] |=
|
|
534
|
-
bit ? (1 << (idim % 8)) : 0;
|
|
535
|
-
}
|
|
536
|
-
}
|
|
537
|
-
}
|
|
538
|
-
|
|
539
545
|
} // anonymous namespace
|
|
540
546
|
|
|
541
547
|
FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
|
|
542
548
|
uint8_t qb,
|
|
543
549
|
const float* centroid_in,
|
|
544
550
|
bool centered) const {
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
551
|
+
// Dispatch on SIMDLevel once here so the distance computer methods
|
|
552
|
+
// call the SIMD-specialized rabitq functions directly (no per-call
|
|
553
|
+
// with_simd_level overhead).
|
|
554
|
+
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
|
|
555
|
+
[&]<SIMDLevel SL>() -> FlatCodesDistanceComputer* {
|
|
556
|
+
if (qb == 0) {
|
|
557
|
+
auto dc =
|
|
558
|
+
std::make_unique<RaBitQDistanceComputerNotQ<SL>>();
|
|
559
|
+
dc->metric_type = metric_type;
|
|
560
|
+
dc->d = d;
|
|
561
|
+
dc->centroid = centroid_in;
|
|
562
|
+
dc->nb_bits = nb_bits;
|
|
563
|
+
|
|
564
|
+
return dc.release();
|
|
565
|
+
} else {
|
|
566
|
+
auto dc = std::make_unique<RaBitQDistanceComputerQ<SL>>();
|
|
567
|
+
dc->metric_type = metric_type;
|
|
568
|
+
dc->d = d;
|
|
569
|
+
dc->centroid = centroid_in;
|
|
570
|
+
dc->qb = qb;
|
|
571
|
+
dc->centered = centered;
|
|
572
|
+
dc->nb_bits = nb_bits;
|
|
573
|
+
|
|
574
|
+
return dc.release();
|
|
575
|
+
}
|
|
576
|
+
});
|
|
564
577
|
}
|
|
565
578
|
|
|
566
579
|
} // namespace faiss
|