faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -12,12 +12,11 @@
|
|
|
12
12
|
|
|
13
13
|
#include <faiss/IndexFastScan.h>
|
|
14
14
|
#include <faiss/IndexRaBitQ.h>
|
|
15
|
-
#include <faiss/impl/RaBitQStats.h>
|
|
16
15
|
#include <faiss/impl/RaBitQUtils.h>
|
|
17
16
|
#include <faiss/impl/RaBitQuantizer.h>
|
|
18
|
-
#include <faiss/impl/simd_result_handlers.h>
|
|
17
|
+
#include <faiss/impl/fast_scan/simd_result_handlers.h>
|
|
18
|
+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
|
|
19
19
|
#include <faiss/utils/Heap.h>
|
|
20
|
-
#include <faiss/utils/simdlib.h>
|
|
21
20
|
|
|
22
21
|
namespace faiss {
|
|
23
22
|
|
|
@@ -78,6 +77,10 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
78
77
|
|
|
79
78
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
80
79
|
|
|
80
|
+
/// Packed code size: (d + 7) / 8 bytes (1-bit-per-dimension sign bits,
|
|
81
|
+
/// excluding factors)
|
|
82
|
+
size_t fast_scan_code_size() const override;
|
|
83
|
+
|
|
81
84
|
/// Return CodePackerRaBitQ with enlarged block size
|
|
82
85
|
CodePacker* get_CodePacker() const override;
|
|
83
86
|
|
|
@@ -92,17 +95,17 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
92
95
|
idx_t* labels,
|
|
93
96
|
const SearchParameters* params = nullptr) const override;
|
|
94
97
|
|
|
95
|
-
///
|
|
96
|
-
|
|
98
|
+
/// RaBitQ scanner wired through rabitq_make_knn_scanner
|
|
99
|
+
std::unique_ptr<FastScanCodeScanner> make_knn_scanner(
|
|
97
100
|
bool is_max,
|
|
98
|
-
int /*impl*/,
|
|
99
101
|
idx_t n,
|
|
100
102
|
idx_t k,
|
|
101
|
-
size_t
|
|
103
|
+
size_t ntotal,
|
|
102
104
|
float* distances,
|
|
103
105
|
idx_t* labels,
|
|
104
106
|
const IDSelector* sel,
|
|
105
|
-
|
|
107
|
+
int impl = 0,
|
|
108
|
+
const FastScanDistancePostProcessing& context = {}) const override;
|
|
106
109
|
};
|
|
107
110
|
|
|
108
111
|
/** SIMD result handler for RaBitQ FastScan that applies distance corrections
|
|
@@ -122,18 +125,24 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
122
125
|
*
|
|
123
126
|
* @tparam C Comparator type (CMin/CMax) for heap operations
|
|
124
127
|
* @tparam with_id_map Whether to use id mapping (similar to HeapHandler)
|
|
128
|
+
* @tparam SL SIMD level for per-TU template instantiation
|
|
125
129
|
*/
|
|
126
|
-
template <
|
|
130
|
+
template <
|
|
131
|
+
class C,
|
|
132
|
+
bool with_id_map = false,
|
|
133
|
+
SIMDLevel SL = SINGLE_SIMD_LEVEL_256>
|
|
127
134
|
struct RaBitQHeapHandler
|
|
128
|
-
: simd_result_handlers::ResultHandlerCompare<C, with_id_map> {
|
|
129
|
-
using RHC = simd_result_handlers::ResultHandlerCompare<C, with_id_map>;
|
|
135
|
+
: simd_result_handlers::ResultHandlerCompare<C, with_id_map, SL> {
|
|
136
|
+
using RHC = simd_result_handlers::ResultHandlerCompare<C, with_id_map, SL>;
|
|
130
137
|
using RHC::normalizers;
|
|
138
|
+
static constexpr SIMDLevel SL256 = simd256_level_selector<SL>::value;
|
|
139
|
+
using simd16uint16 = simd16uint16_tpl<SL256>;
|
|
131
140
|
|
|
132
141
|
const IndexRaBitQFastScan* rabitq_index;
|
|
133
142
|
float* heap_distances; // [nq * k]
|
|
134
143
|
int64_t* heap_labels; // [nq * k]
|
|
135
144
|
const size_t nq, k;
|
|
136
|
-
const FastScanDistancePostProcessing
|
|
145
|
+
const FastScanDistancePostProcessing*
|
|
137
146
|
context; // Processing context with query offset
|
|
138
147
|
const bool is_multi_bit; // Runtime flag for multi-bit mode
|
|
139
148
|
|
|
@@ -141,7 +150,7 @@ struct RaBitQHeapHandler
|
|
|
141
150
|
const size_t storage_size;
|
|
142
151
|
const size_t packed_block_size;
|
|
143
152
|
const size_t full_block_size;
|
|
144
|
-
std::
|
|
153
|
+
std::vector<uint8_t> unpack_buf; // sign bits scratch buffer
|
|
145
154
|
|
|
146
155
|
// Use float-based comparator for heap operations
|
|
147
156
|
using Cfloat = typename std::conditional<
|
|
@@ -156,22 +165,169 @@ struct RaBitQHeapHandler
|
|
|
156
165
|
float* distances,
|
|
157
166
|
int64_t* labels,
|
|
158
167
|
const IDSelector* sel_in,
|
|
159
|
-
const FastScanDistancePostProcessing
|
|
160
|
-
bool multi_bit)
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
168
|
+
const FastScanDistancePostProcessing* ctx,
|
|
169
|
+
bool multi_bit)
|
|
170
|
+
: RHC(nq_val, index->ntotal, sel_in),
|
|
171
|
+
rabitq_index(index),
|
|
172
|
+
heap_distances(distances),
|
|
173
|
+
heap_labels(labels),
|
|
174
|
+
nq(nq_val),
|
|
175
|
+
k(k_val),
|
|
176
|
+
context(ctx),
|
|
177
|
+
is_multi_bit(multi_bit),
|
|
178
|
+
storage_size(index->compute_per_vector_storage_size()),
|
|
179
|
+
packed_block_size(((index->M2 + 1) / 2) * index->bbs),
|
|
180
|
+
full_block_size(index->get_block_stride()),
|
|
181
|
+
unpack_buf((index->d + 7) / 8) {
|
|
182
|
+
#pragma omp parallel for if (nq > 100)
|
|
183
|
+
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
184
|
+
float* heap_dis = heap_distances + q * k;
|
|
185
|
+
int64_t* heap_ids = heap_labels + q * k;
|
|
186
|
+
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
191
|
+
ALIGNED(32) uint16_t d32tab[32];
|
|
192
|
+
d0.store(d32tab);
|
|
193
|
+
d1.store(d32tab + 16);
|
|
194
|
+
|
|
195
|
+
float* const heap_dis = heap_distances + q * k;
|
|
196
|
+
int64_t* const heap_ids = heap_labels + q * k;
|
|
197
|
+
|
|
198
|
+
rabitq_utils::QueryFactorsData query_factors_data = {};
|
|
199
|
+
if (context && context->query_factors != nullptr) {
|
|
200
|
+
query_factors_data = context->query_factors[q];
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
const float one_a = normalizers ? (1.0f / normalizers[2 * q]) : 1.0f;
|
|
204
|
+
const float bias = normalizers ? normalizers[2 * q + 1] : 0.0f;
|
|
205
|
+
|
|
206
|
+
const size_t base_db_idx = this->j0 + b * 32;
|
|
207
|
+
const size_t max_vectors = (base_db_idx < rabitq_index->ntotal)
|
|
208
|
+
? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
|
|
209
|
+
: 0;
|
|
210
|
+
|
|
211
|
+
const size_t block_idx = base_db_idx / rabitq_index->bbs;
|
|
212
|
+
const uint8_t* aux_base = rabitq_index->codes.get() +
|
|
213
|
+
block_idx * full_block_size + packed_block_size;
|
|
214
|
+
|
|
215
|
+
for (size_t i = 0; i < max_vectors; i++) {
|
|
216
|
+
const size_t db_idx = base_db_idx + i;
|
|
217
|
+
const float normalized_distance = d32tab[i] * one_a + bias;
|
|
218
|
+
const uint8_t* base_ptr = aux_base + i * storage_size;
|
|
219
|
+
|
|
220
|
+
if (is_multi_bit) {
|
|
221
|
+
const SignBitFactorsWithError& full_factors =
|
|
222
|
+
*reinterpret_cast<const SignBitFactorsWithError*>(
|
|
223
|
+
base_ptr);
|
|
224
|
+
|
|
225
|
+
float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
|
|
226
|
+
normalized_distance,
|
|
227
|
+
full_factors,
|
|
228
|
+
query_factors_data,
|
|
229
|
+
rabitq_index->centered,
|
|
230
|
+
rabitq_index->qb,
|
|
231
|
+
rabitq_index->d);
|
|
232
|
+
|
|
233
|
+
const bool is_similarity = rabitq_index->metric_type ==
|
|
234
|
+
MetricType::METRIC_INNER_PRODUCT;
|
|
235
|
+
bool should_refine = rabitq_utils::should_refine_candidate(
|
|
236
|
+
dist_1bit,
|
|
237
|
+
full_factors.f_error,
|
|
238
|
+
context && context->query_factors
|
|
239
|
+
? context->query_factors[q].g_error
|
|
240
|
+
: 0.0f,
|
|
241
|
+
heap_dis[0],
|
|
242
|
+
is_similarity);
|
|
243
|
+
|
|
244
|
+
if (should_refine) {
|
|
245
|
+
float dist_full = compute_full_multibit_distance(db_idx, q);
|
|
246
|
+
|
|
247
|
+
if (Cfloat::cmp(heap_dis[0], dist_full)) {
|
|
248
|
+
heap_replace_top<Cfloat>(
|
|
249
|
+
k, heap_dis, heap_ids, dist_full, db_idx);
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
} else {
|
|
253
|
+
const rabitq_utils::SignBitFactors& db_factors =
|
|
254
|
+
*reinterpret_cast<const rabitq_utils::SignBitFactors*>(
|
|
255
|
+
base_ptr);
|
|
256
|
+
|
|
257
|
+
float adjusted_distance =
|
|
258
|
+
rabitq_utils::compute_1bit_adjusted_distance(
|
|
259
|
+
normalized_distance,
|
|
260
|
+
db_factors,
|
|
261
|
+
query_factors_data,
|
|
262
|
+
rabitq_index->centered,
|
|
263
|
+
rabitq_index->qb,
|
|
264
|
+
rabitq_index->d);
|
|
265
|
+
|
|
266
|
+
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
267
|
+
heap_replace_top<Cfloat>(
|
|
268
|
+
k, heap_dis, heap_ids, adjusted_distance, db_idx);
|
|
269
|
+
}
|
|
270
|
+
}
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
void begin(const float* norms) override {
|
|
275
|
+
normalizers = norms;
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
void end() override {
|
|
279
|
+
#pragma omp parallel for if (nq > 100)
|
|
280
|
+
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
281
|
+
float* heap_dis = heap_distances + q * k;
|
|
282
|
+
int64_t* heap_ids = heap_labels + q * k;
|
|
283
|
+
heap_reorder<Cfloat>(k, heap_dis, heap_ids);
|
|
284
|
+
}
|
|
285
|
+
}
|
|
167
286
|
|
|
168
287
|
private:
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
288
|
+
float compute_full_multibit_distance(size_t db_idx, size_t q) {
|
|
289
|
+
const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
|
|
290
|
+
const size_t dim = rabitq_index->d;
|
|
291
|
+
|
|
292
|
+
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
293
|
+
rabitq_index->codes.get(),
|
|
294
|
+
db_idx,
|
|
295
|
+
rabitq_index->bbs,
|
|
296
|
+
packed_block_size,
|
|
297
|
+
full_block_size,
|
|
298
|
+
storage_size);
|
|
299
|
+
|
|
300
|
+
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
301
|
+
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
302
|
+
const ExtraBitsFactors& ex_fac =
|
|
303
|
+
*reinterpret_cast<const ExtraBitsFactors*>(
|
|
304
|
+
base_ptr + sizeof(SignBitFactorsWithError) +
|
|
305
|
+
ex_code_size);
|
|
306
|
+
|
|
307
|
+
const rabitq_utils::QueryFactorsData& query_factors =
|
|
308
|
+
context->query_factors[q];
|
|
309
|
+
|
|
310
|
+
rabitq_utils::unpack_sign_bits_from_packed(
|
|
311
|
+
rabitq_index->codes.get(),
|
|
312
|
+
rabitq_index->bbs,
|
|
313
|
+
rabitq_index->M2,
|
|
314
|
+
db_idx,
|
|
315
|
+
full_block_size,
|
|
316
|
+
unpack_buf.data());
|
|
317
|
+
const uint8_t* sign_bits = unpack_buf.data();
|
|
318
|
+
|
|
319
|
+
return rabitq_utils::compute_full_multibit_distance(
|
|
320
|
+
sign_bits,
|
|
321
|
+
ex_code,
|
|
322
|
+
ex_fac,
|
|
323
|
+
query_factors.rotated_q.data(),
|
|
324
|
+
(rabitq_index->metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
325
|
+
? query_factors.q_dot_c
|
|
326
|
+
: query_factors.qr_to_c_L2sqr,
|
|
327
|
+
dim,
|
|
328
|
+
ex_bits,
|
|
329
|
+
rabitq_index->metric_type);
|
|
330
|
+
}
|
|
175
331
|
};
|
|
176
332
|
|
|
177
333
|
} // namespace faiss
|
|
@@ -18,10 +18,10 @@ namespace faiss {
|
|
|
18
18
|
* IndexRefine
|
|
19
19
|
***************************************************/
|
|
20
20
|
|
|
21
|
-
IndexRefine::IndexRefine(Index*
|
|
22
|
-
: Index(
|
|
23
|
-
base_index(
|
|
24
|
-
refine_index(
|
|
21
|
+
IndexRefine::IndexRefine(Index* base_index_in, Index* refine_index_in)
|
|
22
|
+
: Index(base_index_in->d, base_index_in->metric_type),
|
|
23
|
+
base_index(base_index_in),
|
|
24
|
+
refine_index(refine_index_in) {
|
|
25
25
|
own_fields = own_refine_index = false;
|
|
26
26
|
if (refine_index != nullptr) {
|
|
27
27
|
FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
|
|
@@ -84,6 +84,8 @@ void IndexRefine::search(
|
|
|
84
84
|
|
|
85
85
|
FAISS_THROW_IF_NOT(k > 0);
|
|
86
86
|
FAISS_THROW_IF_NOT(is_trained);
|
|
87
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
88
|
+
n <= INT64_MAX / k_base, "n * k_base would overflow int64");
|
|
87
89
|
idx_t* base_labels = labels;
|
|
88
90
|
float* base_distances = distances;
|
|
89
91
|
std::unique_ptr<idx_t[]> del1;
|
|
@@ -99,8 +101,8 @@ void IndexRefine::search(
|
|
|
99
101
|
base_index->search(
|
|
100
102
|
n, x, k_base, base_distances, base_labels, base_index_params);
|
|
101
103
|
|
|
102
|
-
for (
|
|
103
|
-
|
|
104
|
+
for (idx_t i = 0; i < n * k_base; i++) {
|
|
105
|
+
FAISS_THROW_IF_NOT(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
104
106
|
}
|
|
105
107
|
|
|
106
108
|
// parallelize over queries
|
|
@@ -125,12 +127,12 @@ void IndexRefine::search(
|
|
|
125
127
|
|
|
126
128
|
// sort and store result
|
|
127
129
|
if (metric_type == METRIC_L2) {
|
|
128
|
-
|
|
130
|
+
using C = CMax<float, idx_t>;
|
|
129
131
|
reorder_2_heaps<C>(
|
|
130
132
|
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
131
133
|
|
|
132
134
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
133
|
-
|
|
135
|
+
using C = CMin<float, idx_t>;
|
|
134
136
|
reorder_2_heaps<C>(
|
|
135
137
|
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
136
138
|
} else {
|
|
@@ -191,7 +193,7 @@ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
|
|
191
193
|
base_index->sa_encode(n, x, tmp1.get());
|
|
192
194
|
std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
|
|
193
195
|
refine_index->sa_encode(n, x, tmp2.get());
|
|
194
|
-
for (
|
|
196
|
+
for (idx_t i = 0; i < n; i++) {
|
|
195
197
|
uint8_t* b = bytes + i * (cs1 + cs2);
|
|
196
198
|
memcpy(b, tmp1.get() + cs1 * i, cs1);
|
|
197
199
|
memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
|
|
@@ -200,10 +202,9 @@ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
|
|
200
202
|
|
|
201
203
|
void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
202
204
|
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
|
|
203
|
-
std::unique_ptr<uint8_t[]> tmp2(
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
|
|
205
|
+
std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
|
|
206
|
+
for (idx_t i = 0; i < n; i++) {
|
|
207
|
+
memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2) + cs1, cs2);
|
|
207
208
|
}
|
|
208
209
|
|
|
209
210
|
refine_index->sa_decode(n, tmp2.get(), x);
|
|
@@ -222,10 +223,10 @@ IndexRefine::~IndexRefine() {
|
|
|
222
223
|
* IndexRefineFlat
|
|
223
224
|
***************************************************/
|
|
224
225
|
|
|
225
|
-
IndexRefineFlat::IndexRefineFlat(Index*
|
|
226
|
+
IndexRefineFlat::IndexRefineFlat(Index* base_index_in)
|
|
226
227
|
: IndexRefine(
|
|
227
|
-
|
|
228
|
-
new IndexFlat(
|
|
228
|
+
base_index_in,
|
|
229
|
+
new IndexFlat(base_index_in->d, base_index_in->metric_type)) {
|
|
229
230
|
is_trained = base_index->is_trained;
|
|
230
231
|
own_refine_index = true;
|
|
231
232
|
FAISS_THROW_IF_NOT_MSG(
|
|
@@ -233,8 +234,8 @@ IndexRefineFlat::IndexRefineFlat(Index* base_index)
|
|
|
233
234
|
"base_index should be empty in the beginning");
|
|
234
235
|
}
|
|
235
236
|
|
|
236
|
-
IndexRefineFlat::IndexRefineFlat(Index*
|
|
237
|
-
: IndexRefine(
|
|
237
|
+
IndexRefineFlat::IndexRefineFlat(Index* base_index_in, const float* xb)
|
|
238
|
+
: IndexRefine(base_index_in, nullptr) {
|
|
238
239
|
is_trained = base_index->is_trained;
|
|
239
240
|
refine_index = new IndexFlat(base_index->d, base_index->metric_type);
|
|
240
241
|
own_refine_index = true;
|
|
@@ -271,6 +272,8 @@ void IndexRefineFlat::search(
|
|
|
271
272
|
|
|
272
273
|
FAISS_THROW_IF_NOT(k > 0);
|
|
273
274
|
FAISS_THROW_IF_NOT(is_trained);
|
|
275
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
276
|
+
n <= INT64_MAX / k_base, "n * k_base would overflow int64");
|
|
274
277
|
idx_t* base_labels = labels;
|
|
275
278
|
float* base_distances = distances;
|
|
276
279
|
std::unique_ptr<idx_t[]> del1;
|
|
@@ -286,8 +289,8 @@ void IndexRefineFlat::search(
|
|
|
286
289
|
base_index->search(
|
|
287
290
|
n, x, k_base, base_distances, base_labels, base_index_params);
|
|
288
291
|
|
|
289
|
-
for (
|
|
290
|
-
|
|
292
|
+
for (idx_t i = 0; i < n * k_base; i++) {
|
|
293
|
+
FAISS_THROW_IF_NOT(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
291
294
|
}
|
|
292
295
|
|
|
293
296
|
// compute refined distances
|
|
@@ -298,12 +301,12 @@ void IndexRefineFlat::search(
|
|
|
298
301
|
|
|
299
302
|
// sort and store result
|
|
300
303
|
if (metric_type == METRIC_L2) {
|
|
301
|
-
|
|
304
|
+
using C = CMax<float, idx_t>;
|
|
302
305
|
reorder_2_heaps<C>(
|
|
303
306
|
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
304
307
|
|
|
305
308
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
306
|
-
|
|
309
|
+
using C = CMin<float, idx_t>;
|
|
307
310
|
reorder_2_heaps<C>(
|
|
308
311
|
n, k, labels, distances, k_base, base_labels, base_distances);
|
|
309
312
|
} else {
|
|
@@ -326,7 +329,7 @@ void IndexRefinePanorama::search(
|
|
|
326
329
|
if (params_in) {
|
|
327
330
|
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
|
|
328
331
|
FAISS_THROW_IF_NOT_MSG(
|
|
329
|
-
params, "
|
|
332
|
+
params, "IndexRefinePanorama params have incorrect type");
|
|
330
333
|
}
|
|
331
334
|
|
|
332
335
|
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
|
|
@@ -341,6 +344,8 @@ void IndexRefinePanorama::search(
|
|
|
341
344
|
|
|
342
345
|
FAISS_THROW_IF_NOT(k > 0);
|
|
343
346
|
FAISS_THROW_IF_NOT(is_trained);
|
|
347
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
348
|
+
n <= INT64_MAX / k_base, "n * k_base would overflow int64");
|
|
344
349
|
|
|
345
350
|
std::unique_ptr<idx_t[]> del1;
|
|
346
351
|
std::unique_ptr<float[]> del2;
|
|
@@ -352,8 +357,8 @@ void IndexRefinePanorama::search(
|
|
|
352
357
|
base_index->search(
|
|
353
358
|
n, x, k_base, base_distances, base_labels, base_index_params);
|
|
354
359
|
|
|
355
|
-
for (
|
|
356
|
-
|
|
360
|
+
for (idx_t i = 0; i < n * k_base; i++) {
|
|
361
|
+
FAISS_THROW_IF_NOT(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
357
362
|
}
|
|
358
363
|
|
|
359
364
|
refine_index->search_subset(
|
|
@@ -28,8 +28,8 @@ struct IndexRefine : Index {
|
|
|
28
28
|
/// refinement index
|
|
29
29
|
Index* refine_index;
|
|
30
30
|
|
|
31
|
-
bool own_fields; ///< should the base index be deallocated?
|
|
32
|
-
bool own_refine_index; ///< same with the refinement index
|
|
31
|
+
bool own_fields = false; ///< should the base index be deallocated?
|
|
32
|
+
bool own_refine_index = false; ///< same with the refinement index
|
|
33
33
|
|
|
34
34
|
/// factor between k requested in search and the k requested from
|
|
35
35
|
/// the base_index (should be >= 1)
|
|
@@ -98,8 +98,8 @@ struct IndexRefineFlat : IndexRefine {
|
|
|
98
98
|
/** Version where the search calls search_subset, allowing for Panorama
|
|
99
99
|
* refinement. */
|
|
100
100
|
struct IndexRefinePanorama : IndexRefine {
|
|
101
|
-
explicit IndexRefinePanorama(Index*
|
|
102
|
-
: IndexRefine(
|
|
101
|
+
explicit IndexRefinePanorama(Index* base_index_in, Index* refine_index_in)
|
|
102
|
+
: IndexRefine(base_index_in, refine_index_in) {}
|
|
103
103
|
|
|
104
104
|
IndexRefinePanorama() : IndexRefine() {}
|
|
105
105
|
|
|
@@ -16,7 +16,7 @@ namespace {
|
|
|
16
16
|
|
|
17
17
|
// IndexBinary needs to update the code_size when d is set...
|
|
18
18
|
|
|
19
|
-
void sync_d(Index* index) {}
|
|
19
|
+
void sync_d(Index* /*index*/) {}
|
|
20
20
|
|
|
21
21
|
void sync_d(IndexBinary* index) {
|
|
22
22
|
FAISS_THROW_IF_NOT(index->d % 8 == 0);
|
|
@@ -30,14 +30,14 @@ IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
|
|
|
30
30
|
: ThreadedIndex<IndexT>(threaded) {}
|
|
31
31
|
|
|
32
32
|
template <typename IndexT>
|
|
33
|
-
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t
|
|
34
|
-
: ThreadedIndex<IndexT>(
|
|
33
|
+
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d_in, bool threaded)
|
|
34
|
+
: ThreadedIndex<IndexT>(static_cast<int>(d_in), threaded) {
|
|
35
35
|
sync_d(this);
|
|
36
36
|
}
|
|
37
37
|
|
|
38
38
|
template <typename IndexT>
|
|
39
|
-
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int
|
|
40
|
-
: ThreadedIndex<IndexT>(
|
|
39
|
+
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d_in, bool threaded)
|
|
40
|
+
: ThreadedIndex<IndexT>(d_in, threaded) {
|
|
41
41
|
sync_d(this);
|
|
42
42
|
}
|
|
43
43
|
|
|
@@ -71,7 +71,7 @@ void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
|
|
|
71
71
|
}
|
|
72
72
|
|
|
73
73
|
template <typename IndexT>
|
|
74
|
-
void IndexReplicasTemplate<IndexT>::onAfterRemoveIndex(IndexT* index) {
|
|
74
|
+
void IndexReplicasTemplate<IndexT>::onAfterRemoveIndex(IndexT* /*index*/) {
|
|
75
75
|
syncWithSubIndexes();
|
|
76
76
|
}
|
|
77
77
|
|
|
@@ -162,10 +162,10 @@ void sa_decode_impl(
|
|
|
162
162
|
|
|
163
163
|
// allocate tmp buffers
|
|
164
164
|
std::vector<uint8_t> tmp(
|
|
165
|
-
(chunk_size < n_input
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
165
|
+
(chunk_size < static_cast<size_t>(n_input)
|
|
166
|
+
? chunk_size
|
|
167
|
+
: static_cast<size_t>(n_input)) *
|
|
168
|
+
old_code_size);
|
|
169
169
|
// all the elements to process
|
|
170
170
|
size_t n_left = n_input;
|
|
171
171
|
|
|
@@ -226,7 +226,7 @@ void train_inplace_impl(
|
|
|
226
226
|
std::vector<StorageMinMaxT> minmax(n);
|
|
227
227
|
|
|
228
228
|
// normalize
|
|
229
|
-
#pragma omp for
|
|
229
|
+
#pragma omp parallel for
|
|
230
230
|
for (idx_t i = 0; i < n; i++) {
|
|
231
231
|
// compute min & max values
|
|
232
232
|
float minv = std::numeric_limits<float>::max();
|
|
@@ -264,6 +264,7 @@ void train_inplace_impl(
|
|
|
264
264
|
sub_index->train(n, x);
|
|
265
265
|
|
|
266
266
|
// rescale data back
|
|
267
|
+
#pragma omp parallel for
|
|
267
268
|
for (idx_t i = 0; i < n; i++) {
|
|
268
269
|
float scaler = 0;
|
|
269
270
|
float minv = 0;
|
|
@@ -289,7 +290,7 @@ void train_impl(IndexRowwiseMinMaxBase* const index, idx_t n, const float* x) {
|
|
|
289
290
|
// temp buffer
|
|
290
291
|
std::vector<float> tmp(n * d);
|
|
291
292
|
|
|
292
|
-
#pragma omp for
|
|
293
|
+
#pragma omp parallel for
|
|
293
294
|
for (idx_t i = 0; i < n; i++) {
|
|
294
295
|
// compute min & max values
|
|
295
296
|
float minv = std::numeric_limits<float>::max();
|
|
@@ -304,7 +305,7 @@ void train_impl(IndexRowwiseMinMaxBase* const index, idx_t n, const float* x) {
|
|
|
304
305
|
const float scaler = maxv - minv;
|
|
305
306
|
|
|
306
307
|
// save the coefficients
|
|
307
|
-
StorageMinMaxT storage;
|
|
308
|
+
StorageMinMaxT storage = {};
|
|
308
309
|
storage.from_floats(scaler, minv);
|
|
309
310
|
|
|
310
311
|
// and load them back, because the coefficients might
|
|
@@ -339,9 +340,9 @@ int rowwise_minmax_sa_decode_bs = 16384;
|
|
|
339
340
|
* IndexRowwiseMinMaxBase implementation
|
|
340
341
|
********************************************************/
|
|
341
342
|
|
|
342
|
-
IndexRowwiseMinMaxBase::IndexRowwiseMinMaxBase(Index*
|
|
343
|
-
: Index(
|
|
344
|
-
index{
|
|
343
|
+
IndexRowwiseMinMaxBase::IndexRowwiseMinMaxBase(Index* index_in)
|
|
344
|
+
: Index(index_in->d, index_in->metric_type),
|
|
345
|
+
index{index_in},
|
|
345
346
|
own_fields{false} {}
|
|
346
347
|
|
|
347
348
|
IndexRowwiseMinMaxBase::IndexRowwiseMinMaxBase()
|
|
@@ -376,8 +377,8 @@ void IndexRowwiseMinMaxBase::reset() {
|
|
|
376
377
|
* IndexRowwiseMinMaxFP16 implementation
|
|
377
378
|
********************************************************/
|
|
378
379
|
|
|
379
|
-
IndexRowwiseMinMaxFP16::IndexRowwiseMinMaxFP16(Index*
|
|
380
|
-
: IndexRowwiseMinMaxBase(
|
|
380
|
+
IndexRowwiseMinMaxFP16::IndexRowwiseMinMaxFP16(Index* index_in)
|
|
381
|
+
: IndexRowwiseMinMaxBase(index_in) {}
|
|
381
382
|
|
|
382
383
|
IndexRowwiseMinMaxFP16::IndexRowwiseMinMaxFP16() : IndexRowwiseMinMaxBase() {}
|
|
383
384
|
|
|
@@ -411,8 +412,8 @@ void IndexRowwiseMinMaxFP16::train_inplace(idx_t n, float* x) {
|
|
|
411
412
|
* IndexRowwiseMinMax implementation
|
|
412
413
|
********************************************************/
|
|
413
414
|
|
|
414
|
-
IndexRowwiseMinMax::IndexRowwiseMinMax(Index*
|
|
415
|
-
: IndexRowwiseMinMaxBase(
|
|
415
|
+
IndexRowwiseMinMax::IndexRowwiseMinMax(Index* index_in)
|
|
416
|
+
: IndexRowwiseMinMaxBase(index_in) {}
|
|
416
417
|
|
|
417
418
|
IndexRowwiseMinMax::IndexRowwiseMinMax() : IndexRowwiseMinMaxBase() {}
|
|
418
419
|
|