faiss 0.5.3 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +4 -4
- data/ext/faiss/index.cpp +63 -45
- data/ext/faiss/index_binary.cpp +37 -27
- data/ext/faiss/kmeans.cpp +9 -8
- data/ext/faiss/pca_matrix.cpp +9 -7
- data/ext/faiss/product_quantizer.cpp +13 -11
- data/ext/faiss/utils.cpp +4 -2
- data/ext/faiss/utils.h +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +214 -82
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +97 -249
- data/vendor/faiss/faiss/Clustering.h +18 -0
- data/vendor/faiss/faiss/IVFlib.cpp +67 -44
- data/vendor/faiss/faiss/Index.cpp +25 -12
- data/vendor/faiss/faiss/Index.h +26 -4
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
- data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
- data/vendor/faiss/faiss/IndexFastScan.h +35 -24
- data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
- data/vendor/faiss/faiss/IndexFlat.h +32 -14
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
- data/vendor/faiss/faiss/IndexHNSW.h +30 -14
- data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
- data/vendor/faiss/faiss/IndexIVF.h +47 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
- data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
- data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
- data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
- data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
- data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +13 -13
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +29 -6
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
- data/vendor/faiss/faiss/VectorTransform.h +39 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +55 -51
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
- data/vendor/faiss/faiss/impl/HNSW.h +21 -40
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
- data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
- data/vendor/faiss/faiss/impl/NSG.h +20 -10
- data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
- data/vendor/faiss/faiss/impl/Panorama.h +265 -78
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
- data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +115 -28
- data/vendor/faiss/faiss/index_io.h +53 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +507 -559
- data/vendor/faiss/faiss/utils/distances.h +118 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
- data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
- data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +21 -14
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +156 -42
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -10,8 +10,10 @@
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
11
|
#include <faiss/impl/RaBitQUtils.h>
|
|
12
12
|
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
13
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
13
14
|
#include <faiss/utils/distances.h>
|
|
14
15
|
#include <faiss/utils/rabitq_simd.h>
|
|
16
|
+
|
|
15
17
|
#include <algorithm>
|
|
16
18
|
#include <cmath>
|
|
17
19
|
#include <cstring>
|
|
@@ -26,10 +28,13 @@ using rabitq_utils::QueryFactorsData;
|
|
|
26
28
|
using rabitq_utils::SignBitFactors;
|
|
27
29
|
using rabitq_utils::SignBitFactorsWithError;
|
|
28
30
|
|
|
29
|
-
RaBitQuantizer::RaBitQuantizer(
|
|
30
|
-
|
|
31
|
+
RaBitQuantizer::RaBitQuantizer(
|
|
32
|
+
size_t d_in,
|
|
33
|
+
MetricType metric,
|
|
34
|
+
size_t nb_bits_in)
|
|
35
|
+
: Quantizer(d_in, 0), // code_size will be set below
|
|
31
36
|
metric_type{metric},
|
|
32
|
-
nb_bits{
|
|
37
|
+
nb_bits{nb_bits_in} {
|
|
33
38
|
// Validate nb_bits range
|
|
34
39
|
FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
|
|
35
40
|
|
|
@@ -37,7 +42,7 @@ RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
|
|
|
37
42
|
code_size = compute_code_size(d, nb_bits);
|
|
38
43
|
}
|
|
39
44
|
|
|
40
|
-
size_t RaBitQuantizer::compute_code_size(size_t
|
|
45
|
+
size_t RaBitQuantizer::compute_code_size(size_t d_in, size_t num_bits) const {
|
|
41
46
|
// Validate inputs
|
|
42
47
|
FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
|
|
43
48
|
|
|
@@ -49,7 +54,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
|
|
|
49
54
|
// Layout for multi-bit: [binary_code: (d+7)/8
|
|
50
55
|
// bytes][SignBitFactorsWithError: 12 bytes]
|
|
51
56
|
// factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
|
|
52
|
-
size_t base_size = (
|
|
57
|
+
size_t base_size = (d_in + 7) / 8 +
|
|
53
58
|
(ex_bits == 0 ? sizeof(SignBitFactors)
|
|
54
59
|
: sizeof(SignBitFactorsWithError));
|
|
55
60
|
|
|
@@ -57,13 +62,13 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
|
|
|
57
62
|
// Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
|
|
58
63
|
size_t ex_size = 0;
|
|
59
64
|
if (ex_bits > 0) {
|
|
60
|
-
ex_size = (
|
|
65
|
+
ex_size = (d_in * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
|
|
61
66
|
}
|
|
62
67
|
|
|
63
68
|
return base_size + ex_size;
|
|
64
69
|
}
|
|
65
70
|
|
|
66
|
-
void RaBitQuantizer::train(size_t n
|
|
71
|
+
void RaBitQuantizer::train(size_t /*n*/, const float* /*x*/) {
|
|
67
72
|
// does nothing
|
|
68
73
|
}
|
|
69
74
|
|
|
@@ -91,7 +96,7 @@ void RaBitQuantizer::compute_codes_core(
|
|
|
91
96
|
|
|
92
97
|
// Compute codes
|
|
93
98
|
#pragma omp parallel for if (n > 1000)
|
|
94
|
-
for (int64_t i = 0; i < n; i++) {
|
|
99
|
+
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
|
|
95
100
|
// Pointer to this vector's code
|
|
96
101
|
uint8_t* code = codes + i * code_size;
|
|
97
102
|
|
|
@@ -185,7 +190,7 @@ void RaBitQuantizer::decode_core(
|
|
|
185
190
|
const size_t ex_bits = nb_bits - 1;
|
|
186
191
|
|
|
187
192
|
#pragma omp parallel for if (n > 1000)
|
|
188
|
-
for (int64_t i = 0; i < n; i++) {
|
|
193
|
+
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
|
|
189
194
|
const uint8_t* code = codes + i * code_size;
|
|
190
195
|
|
|
191
196
|
// split the code into parts
|
|
@@ -215,183 +220,161 @@ void RaBitQuantizer::decode_core(
|
|
|
215
220
|
}
|
|
216
221
|
}
|
|
217
222
|
|
|
218
|
-
// Implementation of RaBitQDistanceComputer (declared in header)
|
|
219
|
-
|
|
220
|
-
float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
|
|
221
|
-
FAISS_ASSERT(code != nullptr);
|
|
222
|
-
|
|
223
|
-
// Compute estimated distance using 1-bit codes
|
|
224
|
-
float est_distance = distance_to_code_1bit(code);
|
|
225
|
-
|
|
226
|
-
// Extract f_error from the code
|
|
227
|
-
size_t size = (d + 7) / 8;
|
|
228
|
-
const SignBitFactorsWithError* base_fac =
|
|
229
|
-
reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
230
|
-
float f_error = base_fac->f_error;
|
|
231
|
-
|
|
232
|
-
// Compute proper lower bound using RaBitQ error formula:
|
|
233
|
-
// lower_bound = est_distance - f_error * g_error
|
|
234
|
-
// This guarantees: lower_bound ≤ true_distance
|
|
235
|
-
float lower_bound = est_distance - (f_error * g_error);
|
|
236
|
-
|
|
237
|
-
// Distance cannot be negative
|
|
238
|
-
return std::max(0.0f, lower_bound);
|
|
239
|
-
}
|
|
240
|
-
|
|
241
223
|
namespace {
|
|
242
224
|
|
|
225
|
+
// Distance computers templatized on SIMDLevel to avoid per-call dynamic
|
|
226
|
+
// dispatch. The SIMDLevel is baked in at construction time via
|
|
227
|
+
// get_distance_computer, so virtual calls through the base class go
|
|
228
|
+
// directly to the SIMD-specialized code.
|
|
229
|
+
|
|
230
|
+
template <SIMDLevel SL>
|
|
243
231
|
struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
|
|
244
232
|
// the rotated query (qr - c)
|
|
245
233
|
std::vector<float> rotated_q;
|
|
246
234
|
// some additional numbers for the query
|
|
247
235
|
QueryFactorsData query_fac;
|
|
248
236
|
|
|
249
|
-
RaBitQDistanceComputerNotQ();
|
|
237
|
+
RaBitQDistanceComputerNotQ() = default;
|
|
250
238
|
|
|
251
239
|
// Compute distance using only 1-bit codes (fast)
|
|
252
|
-
float distance_to_code_1bit(const uint8_t* code) override
|
|
240
|
+
float distance_to_code_1bit(const uint8_t* code) override {
|
|
241
|
+
FAISS_ASSERT(code != nullptr);
|
|
242
|
+
FAISS_ASSERT(
|
|
243
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
244
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
245
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
253
246
|
|
|
254
|
-
|
|
255
|
-
|
|
247
|
+
// split the code into parts
|
|
248
|
+
const uint8_t* binary_data = code;
|
|
256
249
|
|
|
257
|
-
|
|
258
|
-
|
|
250
|
+
// Cast to appropriate type based on nb_bits
|
|
251
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
252
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
253
|
+
// f_error
|
|
254
|
+
size_t ex_bits = nb_bits - 1;
|
|
255
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
256
|
+
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
257
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
258
|
+
code + (d + 7) / 8);
|
|
259
259
|
|
|
260
|
-
|
|
260
|
+
// this is the baseline code
|
|
261
|
+
//
|
|
262
|
+
// compute <q,o> using floats
|
|
263
|
+
float dot_qo = 0;
|
|
264
|
+
// It was a willful decision (after the discussion) to not to pre-cache
|
|
265
|
+
// the sum of all bits, just in order to reduce the overhead per
|
|
266
|
+
// vector.
|
|
267
|
+
uint64_t sum_q = 0;
|
|
268
|
+
|
|
269
|
+
for (size_t i = 0; i < d; i++) {
|
|
270
|
+
// Extract i-th bit
|
|
271
|
+
bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
|
|
272
|
+
// accumulate dp
|
|
273
|
+
dot_qo += bit ? rotated_q[i] : 0;
|
|
274
|
+
// accumulate sum-of-bits
|
|
275
|
+
sum_q += bit ? 1 : 0;
|
|
276
|
+
}
|
|
261
277
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
(metric_type == MetricType::METRIC_L2 ||
|
|
266
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
267
|
-
FAISS_ASSERT(rotated_q.size() == d);
|
|
268
|
-
|
|
269
|
-
// split the code into parts
|
|
270
|
-
const uint8_t* binary_data = code;
|
|
271
|
-
|
|
272
|
-
// Cast to appropriate type based on nb_bits
|
|
273
|
-
// For 1-bit: use SignBitFactors (8 bytes)
|
|
274
|
-
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
275
|
-
// f_error
|
|
276
|
-
size_t ex_bits = nb_bits - 1;
|
|
277
|
-
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
278
|
-
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
279
|
-
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
280
|
-
code + (d + 7) / 8);
|
|
281
|
-
|
|
282
|
-
// this is the baseline code
|
|
283
|
-
//
|
|
284
|
-
// compute <q,o> using floats
|
|
285
|
-
float dot_qo = 0;
|
|
286
|
-
// It was a willful decision (after the discussion) to not to pre-cache
|
|
287
|
-
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
288
|
-
uint64_t sum_q = 0;
|
|
289
|
-
|
|
290
|
-
for (size_t i = 0; i < d; i++) {
|
|
291
|
-
// Extract i-th bit
|
|
292
|
-
bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
|
|
293
|
-
// accumulate dp
|
|
294
|
-
dot_qo += bit ? rotated_q[i] : 0;
|
|
295
|
-
// accumulate sum-of-bits
|
|
296
|
-
sum_q += bit ? 1 : 0;
|
|
297
|
-
}
|
|
278
|
+
// Apply query factors
|
|
279
|
+
float final_dot =
|
|
280
|
+
query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
|
|
298
281
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
313
|
-
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
282
|
+
// pre_dist = ||or - c||^2 + ||qr - c||^2 -
|
|
283
|
+
// 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
|
|
284
|
+
float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
|
|
285
|
+
2 * base_fac->dp_multiplier * final_dot;
|
|
286
|
+
|
|
287
|
+
if (metric_type == MetricType::METRIC_L2) {
|
|
288
|
+
// ||or - q||^2
|
|
289
|
+
return pre_dist;
|
|
290
|
+
} else {
|
|
291
|
+
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
292
|
+
// -2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
293
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
294
|
+
}
|
|
314
295
|
}
|
|
315
|
-
}
|
|
316
296
|
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
297
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
298
|
+
float distance_to_code_full(const uint8_t* code) override {
|
|
299
|
+
FAISS_ASSERT(code != nullptr);
|
|
300
|
+
FAISS_ASSERT(
|
|
301
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
302
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
303
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
323
304
|
|
|
324
|
-
|
|
305
|
+
size_t ex_bits = nb_bits - 1;
|
|
325
306
|
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
307
|
+
if (ex_bits == 0) {
|
|
308
|
+
// No ex-bits, just return 1-bit distance
|
|
309
|
+
return distance_to_code_1bit(code);
|
|
310
|
+
}
|
|
330
311
|
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
312
|
+
// Extract pointers to code sections
|
|
313
|
+
const uint8_t* binary_data = code;
|
|
314
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
315
|
+
const uint8_t* ex_code = code + offset;
|
|
316
|
+
const ExtraBitsFactors* ex_fac =
|
|
317
|
+
reinterpret_cast<const ExtraBitsFactors*>(
|
|
318
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
319
|
+
|
|
320
|
+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
321
|
+
? query_fac.q_dot_c
|
|
322
|
+
: query_fac.qr_to_c_L2sqr;
|
|
323
|
+
return rabitq_utils::compute_full_multibit_distance<SL>(
|
|
324
|
+
binary_data,
|
|
325
|
+
ex_code,
|
|
326
|
+
*ex_fac,
|
|
327
|
+
rotated_q.data(),
|
|
328
|
+
qr_base,
|
|
329
|
+
d,
|
|
330
|
+
ex_bits,
|
|
331
|
+
metric_type);
|
|
332
|
+
}
|
|
350
333
|
|
|
351
|
-
void
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
334
|
+
void set_query(const float* x) override {
|
|
335
|
+
q = x;
|
|
336
|
+
FAISS_ASSERT(x != nullptr);
|
|
337
|
+
FAISS_ASSERT(
|
|
338
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
339
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
357
340
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
341
|
+
// compute the distance from the query to the centroid
|
|
342
|
+
if (centroid != nullptr) {
|
|
343
|
+
query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
|
|
344
|
+
} else {
|
|
345
|
+
query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
|
|
346
|
+
}
|
|
364
347
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
348
|
+
// subtract c, obtain P^(-1)(qr - c)
|
|
349
|
+
rotated_q.resize(d);
|
|
350
|
+
for (size_t i = 0; i < d; i++) {
|
|
351
|
+
rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
|
|
352
|
+
}
|
|
370
353
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
354
|
+
// Compute g_error = ||qr - c|| (L2 norm of rotated query)
|
|
355
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
374
356
|
|
|
375
|
-
|
|
376
|
-
|
|
357
|
+
// compute some numbers — do not quantize the query
|
|
358
|
+
const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
377
359
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
}
|
|
360
|
+
float sum_q = 0;
|
|
361
|
+
for (size_t i = 0; i < d; i++) {
|
|
362
|
+
sum_q += rotated_q[i];
|
|
363
|
+
}
|
|
383
364
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
365
|
+
query_fac.c1 = 2 * inv_d;
|
|
366
|
+
query_fac.c2 = 0;
|
|
367
|
+
query_fac.c34 = sum_q * inv_d;
|
|
387
368
|
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
369
|
+
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
|
|
370
|
+
query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
|
|
371
|
+
query_fac.q_dot_c =
|
|
372
|
+
centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
|
|
373
|
+
}
|
|
391
374
|
}
|
|
392
|
-
}
|
|
375
|
+
};
|
|
393
376
|
|
|
394
|
-
|
|
377
|
+
template <SIMDLevel SL>
|
|
395
378
|
struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
|
|
396
379
|
// the rotated and quantized query (qr - c)
|
|
397
380
|
std::vector<float> rotated_q;
|
|
@@ -409,174 +392,188 @@ struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
|
|
|
409
392
|
// the smallest value divisible by 8 that is not smaller than dim
|
|
410
393
|
size_t popcount_aligned_dim = 0;
|
|
411
394
|
|
|
412
|
-
RaBitQDistanceComputerQ();
|
|
395
|
+
RaBitQDistanceComputerQ() = default;
|
|
413
396
|
|
|
414
397
|
// Compute distance using only 1-bit codes (fast)
|
|
415
|
-
float distance_to_code_1bit(const uint8_t* code) override
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
398
|
+
float distance_to_code_1bit(const uint8_t* code) override {
|
|
399
|
+
FAISS_ASSERT(code != nullptr);
|
|
400
|
+
FAISS_ASSERT(
|
|
401
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
402
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
419
403
|
|
|
420
|
-
|
|
421
|
-
|
|
404
|
+
// split the code into parts
|
|
405
|
+
size_t size = (d + 7) / 8;
|
|
406
|
+
const uint8_t* binary_data = code;
|
|
422
407
|
|
|
423
|
-
|
|
408
|
+
// Cast to appropriate type based on nb_bits
|
|
409
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
410
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which
|
|
411
|
+
// includes f_error
|
|
412
|
+
size_t ex_bits = nb_bits - 1;
|
|
413
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
414
|
+
? reinterpret_cast<const SignBitFactors*>(code + size)
|
|
415
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
416
|
+
|
|
417
|
+
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
418
|
+
float final_dot = 0;
|
|
419
|
+
if (centered) {
|
|
420
|
+
int64_t int_dot = ((1 << qb) - 1) * d;
|
|
421
|
+
// See RaBitQDistanceComputerNotQ::distance_to_code() for
|
|
422
|
+
// baseline code.
|
|
423
|
+
int_dot -= 2 *
|
|
424
|
+
rabitq::bitwise_xor_dot_product<SL>(
|
|
425
|
+
rearranged_rotated_qq.data(),
|
|
426
|
+
binary_data,
|
|
427
|
+
size,
|
|
428
|
+
qb);
|
|
429
|
+
final_dot += int_dot * query_fac.int_dot_scale;
|
|
430
|
+
} else {
|
|
431
|
+
auto dot_qo = rabitq::bitwise_and_dot_product<SL>(
|
|
432
|
+
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
433
|
+
// It was a willful decision (after the discussion) not to
|
|
434
|
+
// pre-cache the sum of all bits, just in order to reduce the
|
|
435
|
+
// overhead per vector.
|
|
436
|
+
// process 64-bit popcounts
|
|
437
|
+
auto sum_q = rabitq::popcount<SL>(binary_data, size);
|
|
438
|
+
// dot-product itself
|
|
439
|
+
final_dot += query_fac.c1 * dot_qo;
|
|
440
|
+
// normalizer coefficients
|
|
441
|
+
final_dot += query_fac.c2 * sum_q;
|
|
442
|
+
// normalizer coefficients
|
|
443
|
+
final_dot -= query_fac.c34;
|
|
444
|
+
}
|
|
424
445
|
|
|
425
|
-
float
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
(metric_type == MetricType::METRIC_L2 ||
|
|
429
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
446
|
+
const float pre_dist = base_fac->or_minus_c_l2sqr +
|
|
447
|
+
query_fac.qr_to_c_L2sqr -
|
|
448
|
+
2 * base_fac->dp_multiplier * final_dot;
|
|
430
449
|
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
size_t ex_bits = nb_bits - 1;
|
|
440
|
-
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
441
|
-
? reinterpret_cast<const SignBitFactors*>(code + size)
|
|
442
|
-
: reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
443
|
-
|
|
444
|
-
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
445
|
-
float final_dot = 0;
|
|
446
|
-
if (centered) {
|
|
447
|
-
int64_t int_dot = ((1 << qb) - 1) * d;
|
|
448
|
-
// See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
|
|
449
|
-
int_dot -= 2 *
|
|
450
|
-
rabitq::bitwise_xor_dot_product(
|
|
451
|
-
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
452
|
-
final_dot += int_dot * query_fac.int_dot_scale;
|
|
453
|
-
} else {
|
|
454
|
-
auto dot_qo = rabitq::bitwise_and_dot_product(
|
|
455
|
-
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
456
|
-
// It was a willful decision (after the discussion) to not to pre-cache
|
|
457
|
-
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
458
|
-
// process 64-bit popcounts
|
|
459
|
-
auto sum_q = rabitq::popcount(binary_data, size);
|
|
460
|
-
// dot-product itself
|
|
461
|
-
final_dot += query_fac.c1 * dot_qo;
|
|
462
|
-
// normalizer coefficients
|
|
463
|
-
final_dot += query_fac.c2 * sum_q;
|
|
464
|
-
// normalizer coefficients
|
|
465
|
-
final_dot -= query_fac.c34;
|
|
450
|
+
if (metric_type == MetricType::METRIC_L2) {
|
|
451
|
+
// ||or - q||^2
|
|
452
|
+
return pre_dist;
|
|
453
|
+
} else {
|
|
454
|
+
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
455
|
+
// -2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
456
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
457
|
+
}
|
|
466
458
|
}
|
|
467
459
|
|
|
468
|
-
//
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
return pre_dist;
|
|
476
|
-
} else {
|
|
477
|
-
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
478
|
-
// 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
479
|
-
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
480
|
-
}
|
|
481
|
-
}
|
|
460
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
461
|
+
float distance_to_code_full(const uint8_t* code) override {
|
|
462
|
+
FAISS_ASSERT(code != nullptr);
|
|
463
|
+
FAISS_ASSERT(
|
|
464
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
465
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
466
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
482
467
|
|
|
483
|
-
|
|
484
|
-
FAISS_ASSERT(code != nullptr);
|
|
485
|
-
FAISS_ASSERT(
|
|
486
|
-
(metric_type == MetricType::METRIC_L2 ||
|
|
487
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
488
|
-
FAISS_ASSERT(rotated_q.size() == d);
|
|
468
|
+
size_t ex_bits = nb_bits - 1;
|
|
489
469
|
|
|
490
|
-
|
|
470
|
+
if (ex_bits == 0) {
|
|
471
|
+
// No ex-bits, just return 1-bit distance
|
|
472
|
+
return distance_to_code_1bit(code);
|
|
473
|
+
}
|
|
491
474
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
475
|
+
// Extract pointers to code sections
|
|
476
|
+
const uint8_t* binary_data = code;
|
|
477
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
478
|
+
const uint8_t* ex_code = code + offset;
|
|
479
|
+
const ExtraBitsFactors* ex_fac =
|
|
480
|
+
reinterpret_cast<const ExtraBitsFactors*>(
|
|
481
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
482
|
+
|
|
483
|
+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
484
|
+
? query_fac.q_dot_c
|
|
485
|
+
: query_fac.qr_to_c_L2sqr;
|
|
486
|
+
return rabitq_utils::compute_full_multibit_distance<SL>(
|
|
487
|
+
binary_data,
|
|
488
|
+
ex_code,
|
|
489
|
+
*ex_fac,
|
|
490
|
+
rotated_q.data(),
|
|
491
|
+
qr_base,
|
|
492
|
+
d,
|
|
493
|
+
ex_bits,
|
|
494
|
+
metric_type);
|
|
495
495
|
}
|
|
496
496
|
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
497
|
+
void set_query(const float* x) override {
|
|
498
|
+
q = x;
|
|
499
|
+
FAISS_ASSERT(x != nullptr);
|
|
500
|
+
FAISS_ASSERT(
|
|
501
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
502
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
503
|
+
FAISS_THROW_IF_NOT(qb <= 8);
|
|
504
|
+
FAISS_THROW_IF_NOT(qb > 0);
|
|
505
|
+
|
|
506
|
+
// Use shared utilities for core query factor computation
|
|
507
|
+
// rotated_q is populated directly by compute_query_factors as an
|
|
508
|
+
// output parameter
|
|
509
|
+
query_fac = rabitq_utils::compute_query_factors(
|
|
510
|
+
x,
|
|
511
|
+
d,
|
|
512
|
+
centroid,
|
|
513
|
+
qb,
|
|
514
|
+
centered,
|
|
515
|
+
metric_type,
|
|
516
|
+
rotated_q,
|
|
517
|
+
rotated_qq);
|
|
518
|
+
|
|
519
|
+
// Compute g_error (query norm for lower bound computation)
|
|
520
|
+
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
521
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
522
|
+
|
|
523
|
+
// Rearrange the query vector for SIMD operations
|
|
524
|
+
// (RaBitQuantizer-specific)
|
|
525
|
+
popcount_aligned_dim = ((d + 7) / 8) * 8;
|
|
526
|
+
size_t offset = (d + 7) / 8;
|
|
527
|
+
|
|
528
|
+
rearranged_rotated_qq.resize(offset * qb);
|
|
529
|
+
std::fill(
|
|
530
|
+
rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
|
|
531
|
+
|
|
532
|
+
for (size_t idim = 0; idim < d; idim++) {
|
|
533
|
+
for (size_t iv = 0; iv < qb; iv++) {
|
|
534
|
+
const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
|
|
535
|
+
rearranged_rotated_qq[iv * offset + idim / 8] |=
|
|
536
|
+
bit ? (1 << (idim % 8)) : 0;
|
|
537
|
+
}
|
|
538
|
+
}
|
|
539
|
+
}
|
|
540
|
+
};
|
|
516
541
|
|
|
517
542
|
// Use shared constant from RaBitQUtils
|
|
518
543
|
using rabitq_utils::Z_MAX_BY_QB;
|
|
519
544
|
|
|
520
|
-
void RaBitQDistanceComputerQ::set_query(const float* x) {
|
|
521
|
-
q = x;
|
|
522
|
-
FAISS_ASSERT(x != nullptr);
|
|
523
|
-
FAISS_ASSERT(
|
|
524
|
-
(metric_type == MetricType::METRIC_L2 ||
|
|
525
|
-
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
526
|
-
FAISS_THROW_IF_NOT(qb <= 8);
|
|
527
|
-
FAISS_THROW_IF_NOT(qb > 0);
|
|
528
|
-
|
|
529
|
-
// Use shared utilities for core query factor computation
|
|
530
|
-
// rotated_q is populated directly by compute_query_factors as an output
|
|
531
|
-
// parameter
|
|
532
|
-
query_fac = rabitq_utils::compute_query_factors(
|
|
533
|
-
x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
|
|
534
|
-
|
|
535
|
-
// Compute g_error (query norm for lower bound computation)
|
|
536
|
-
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
537
|
-
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
538
|
-
|
|
539
|
-
// Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
|
|
540
|
-
popcount_aligned_dim = ((d + 7) / 8) * 8;
|
|
541
|
-
size_t offset = (d + 7) / 8;
|
|
542
|
-
|
|
543
|
-
rearranged_rotated_qq.resize(offset * qb);
|
|
544
|
-
std::fill(rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
|
|
545
|
-
|
|
546
|
-
for (size_t idim = 0; idim < d; idim++) {
|
|
547
|
-
for (size_t iv = 0; iv < qb; iv++) {
|
|
548
|
-
const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
|
|
549
|
-
rearranged_rotated_qq[iv * offset + idim / 8] |=
|
|
550
|
-
bit ? (1 << (idim % 8)) : 0;
|
|
551
|
-
}
|
|
552
|
-
}
|
|
553
|
-
}
|
|
554
|
-
|
|
555
545
|
} // anonymous namespace
|
|
556
546
|
|
|
557
547
|
FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
|
|
558
548
|
uint8_t qb,
|
|
559
549
|
const float* centroid_in,
|
|
560
550
|
bool centered) const {
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
551
|
+
// Dispatch on SIMDLevel once here so the distance computer methods
|
|
552
|
+
// call the SIMD-specialized rabitq functions directly (no per-call
|
|
553
|
+
// with_simd_level overhead).
|
|
554
|
+
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
|
|
555
|
+
[&]<SIMDLevel SL>() -> FlatCodesDistanceComputer* {
|
|
556
|
+
if (qb == 0) {
|
|
557
|
+
auto dc =
|
|
558
|
+
std::make_unique<RaBitQDistanceComputerNotQ<SL>>();
|
|
559
|
+
dc->metric_type = metric_type;
|
|
560
|
+
dc->d = d;
|
|
561
|
+
dc->centroid = centroid_in;
|
|
562
|
+
dc->nb_bits = nb_bits;
|
|
563
|
+
|
|
564
|
+
return dc.release();
|
|
565
|
+
} else {
|
|
566
|
+
auto dc = std::make_unique<RaBitQDistanceComputerQ<SL>>();
|
|
567
|
+
dc->metric_type = metric_type;
|
|
568
|
+
dc->d = d;
|
|
569
|
+
dc->centroid = centroid_in;
|
|
570
|
+
dc->qb = qb;
|
|
571
|
+
dc->centered = centered;
|
|
572
|
+
dc->nb_bits = nb_bits;
|
|
573
|
+
|
|
574
|
+
return dc.release();
|
|
575
|
+
}
|
|
576
|
+
});
|
|
580
577
|
}
|
|
581
578
|
|
|
582
579
|
} // namespace faiss
|