faiss 0.5.3 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +4 -4
- data/ext/faiss/index.cpp +63 -45
- data/ext/faiss/index_binary.cpp +37 -27
- data/ext/faiss/kmeans.cpp +9 -8
- data/ext/faiss/pca_matrix.cpp +9 -7
- data/ext/faiss/product_quantizer.cpp +13 -11
- data/ext/faiss/utils.cpp +4 -2
- data/ext/faiss/utils.h +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +214 -82
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +97 -249
- data/vendor/faiss/faiss/Clustering.h +18 -0
- data/vendor/faiss/faiss/IVFlib.cpp +67 -44
- data/vendor/faiss/faiss/Index.cpp +25 -12
- data/vendor/faiss/faiss/Index.h +26 -4
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
- data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
- data/vendor/faiss/faiss/IndexFastScan.h +35 -24
- data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
- data/vendor/faiss/faiss/IndexFlat.h +32 -14
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
- data/vendor/faiss/faiss/IndexHNSW.h +30 -14
- data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
- data/vendor/faiss/faiss/IndexIVF.h +47 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
- data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
- data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
- data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
- data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
- data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +13 -13
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +29 -6
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
- data/vendor/faiss/faiss/VectorTransform.h +39 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +55 -51
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
- data/vendor/faiss/faiss/impl/HNSW.h +21 -40
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
- data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
- data/vendor/faiss/faiss/impl/NSG.h +20 -10
- data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
- data/vendor/faiss/faiss/impl/Panorama.h +265 -78
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
- data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +115 -28
- data/vendor/faiss/faiss/index_io.h +53 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +507 -559
- data/vendor/faiss/faiss/utils/distances.h +118 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
- data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
- data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +21 -14
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +156 -42
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -13,21 +13,19 @@
|
|
|
13
13
|
#include <cstddef>
|
|
14
14
|
#include <cstdio>
|
|
15
15
|
#include <cstring>
|
|
16
|
+
#include <vector>
|
|
16
17
|
|
|
17
18
|
#include <omp.h>
|
|
18
19
|
|
|
19
|
-
#ifdef __AVX2__
|
|
20
|
-
#include <immintrin.h>
|
|
21
|
-
#elif defined(__ARM_FEATURE_SVE)
|
|
22
|
-
#include <arm_sve.h>
|
|
23
|
-
#endif
|
|
24
|
-
|
|
25
20
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
26
21
|
#include <faiss/impl/FaissAssert.h>
|
|
27
22
|
#include <faiss/impl/IDSelector.h>
|
|
28
23
|
#include <faiss/impl/ResultHandler.h>
|
|
29
24
|
|
|
25
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
26
|
+
#include <faiss/utils/distances_dispatch.h>
|
|
30
27
|
#include <faiss/utils/distances_fused/distances_fused.h>
|
|
28
|
+
#include <faiss/utils/simd_impl/exhaustive_L2sqr_blas_cmax.h>
|
|
31
29
|
|
|
32
30
|
#ifndef FINTEGER
|
|
33
31
|
#define FINTEGER long
|
|
@@ -55,6 +53,146 @@ int sgemm_(
|
|
|
55
53
|
|
|
56
54
|
namespace faiss {
|
|
57
55
|
|
|
56
|
+
/***************************************************************************
|
|
57
|
+
* Public API dispatch wrappers
|
|
58
|
+
***************************************************************************/
|
|
59
|
+
|
|
60
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
61
|
+
return fvec_L1_dispatch(x, y, d);
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
65
|
+
return fvec_Linf_dispatch(x, y, d);
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
69
|
+
return fvec_norm_L2sqr_dispatch(x, d);
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
73
|
+
return fvec_L2sqr_dispatch(x, y, d);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
77
|
+
return fvec_inner_product_dispatch(x, y, d);
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
void fvec_inner_product_batch_4(
|
|
81
|
+
const float* x,
|
|
82
|
+
const float* y0,
|
|
83
|
+
const float* y1,
|
|
84
|
+
const float* y2,
|
|
85
|
+
const float* y3,
|
|
86
|
+
const size_t d,
|
|
87
|
+
float& dis0,
|
|
88
|
+
float& dis1,
|
|
89
|
+
float& dis2,
|
|
90
|
+
float& dis3) {
|
|
91
|
+
fvec_inner_product_batch_4_dispatch(
|
|
92
|
+
x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3);
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
void fvec_L2sqr_batch_4(
|
|
96
|
+
const float* x,
|
|
97
|
+
const float* y0,
|
|
98
|
+
const float* y1,
|
|
99
|
+
const float* y2,
|
|
100
|
+
const float* y3,
|
|
101
|
+
const size_t d,
|
|
102
|
+
float& dis0,
|
|
103
|
+
float& dis1,
|
|
104
|
+
float& dis2,
|
|
105
|
+
float& dis3) {
|
|
106
|
+
fvec_L2sqr_batch_4_dispatch(x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3);
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
void fvec_L2sqr_ny_transposed(
|
|
110
|
+
float* dis,
|
|
111
|
+
const float* x,
|
|
112
|
+
const float* y,
|
|
113
|
+
const float* y_sqlen,
|
|
114
|
+
size_t d,
|
|
115
|
+
size_t d_offset,
|
|
116
|
+
size_t ny) {
|
|
117
|
+
fvec_L2sqr_ny_transposed_dispatch(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
void fvec_inner_products_ny(
|
|
121
|
+
float* ip,
|
|
122
|
+
const float* x,
|
|
123
|
+
const float* y,
|
|
124
|
+
size_t d,
|
|
125
|
+
size_t ny) {
|
|
126
|
+
fvec_inner_products_ny_dispatch(ip, x, y, d, ny);
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
void fvec_L2sqr_ny(
|
|
130
|
+
float* dis,
|
|
131
|
+
const float* x,
|
|
132
|
+
const float* y,
|
|
133
|
+
size_t d,
|
|
134
|
+
size_t ny) {
|
|
135
|
+
fvec_L2sqr_ny_dispatch(dis, x, y, d, ny);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
size_t fvec_L2sqr_ny_nearest(
|
|
139
|
+
float* distances_tmp_buffer,
|
|
140
|
+
const float* x,
|
|
141
|
+
const float* y,
|
|
142
|
+
size_t d,
|
|
143
|
+
size_t ny) {
|
|
144
|
+
return fvec_L2sqr_ny_nearest_dispatch(distances_tmp_buffer, x, y, d, ny);
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
148
|
+
float* distances_tmp_buffer,
|
|
149
|
+
const float* x,
|
|
150
|
+
const float* y,
|
|
151
|
+
const float* y_sqlen,
|
|
152
|
+
size_t d,
|
|
153
|
+
size_t d_offset,
|
|
154
|
+
size_t ny) {
|
|
155
|
+
return fvec_L2sqr_ny_nearest_y_transposed_dispatch(
|
|
156
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
|
|
160
|
+
fvec_madd_dispatch(n, a, bf, b, c);
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
int fvec_madd_and_argmin(
|
|
164
|
+
size_t n,
|
|
165
|
+
const float* a,
|
|
166
|
+
float bf,
|
|
167
|
+
const float* b,
|
|
168
|
+
float* c) {
|
|
169
|
+
return fvec_madd_and_argmin_dispatch(n, a, bf, b, c);
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
|
|
173
|
+
fvec_sub_dispatch(d, a, b, c);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
void fvec_add(size_t d, const float* a, const float* b, float* c) {
|
|
177
|
+
fvec_add_dispatch(d, a, b, c);
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
void fvec_add(size_t d, const float* a, float b, float* c) {
|
|
181
|
+
fvec_add_scalar_dispatch(d, a, b, c);
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
void compute_PQ_dis_tables_dsub2(
|
|
185
|
+
size_t d,
|
|
186
|
+
size_t ksub,
|
|
187
|
+
const float* all_centroids,
|
|
188
|
+
size_t nx,
|
|
189
|
+
const float* x,
|
|
190
|
+
bool is_inner_product,
|
|
191
|
+
float* dis_tables) {
|
|
192
|
+
compute_PQ_dis_tables_dsub2_dispatch(
|
|
193
|
+
d, ksub, all_centroids, nx, x, is_inner_product, dis_tables);
|
|
194
|
+
}
|
|
195
|
+
|
|
58
196
|
/***************************************************************************
|
|
59
197
|
* Matrix/vector ops
|
|
60
198
|
***************************************************************************/
|
|
@@ -65,10 +203,12 @@ void fvec_norms_L2(
|
|
|
65
203
|
const float* __restrict x,
|
|
66
204
|
size_t d,
|
|
67
205
|
size_t nx) {
|
|
206
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
68
207
|
#pragma omp parallel for if (nx > 10000)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
208
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
209
|
+
nr[i] = sqrtf(fvec_norm_L2sqr<SL>(x + i * d, d));
|
|
210
|
+
}
|
|
211
|
+
});
|
|
72
212
|
}
|
|
73
213
|
|
|
74
214
|
void fvec_norms_L2sqr(
|
|
@@ -76,10 +216,12 @@ void fvec_norms_L2sqr(
|
|
|
76
216
|
const float* __restrict x,
|
|
77
217
|
size_t d,
|
|
78
218
|
size_t nx) {
|
|
219
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
79
220
|
#pragma omp parallel for if (nx > 10000)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
221
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
222
|
+
nr[i] = fvec_norm_L2sqr<SL>(x + i * d, d);
|
|
223
|
+
}
|
|
224
|
+
});
|
|
83
225
|
}
|
|
84
226
|
|
|
85
227
|
// The following is a workaround to a problem
|
|
@@ -93,29 +235,35 @@ void fvec_norms_L2sqr(
|
|
|
93
235
|
// The workaround below is explicitly branching
|
|
94
236
|
// off to a codepath without omp.
|
|
95
237
|
|
|
96
|
-
#define FVEC_RENORM_L2_IMPL \
|
|
97
|
-
float* __restrict xi = x + i * d; \
|
|
98
|
-
\
|
|
99
|
-
float nr = fvec_norm_L2sqr(xi, d); \
|
|
100
|
-
\
|
|
101
|
-
if (nr > 0) { \
|
|
102
|
-
size_t j; \
|
|
103
|
-
const float inv_nr = 1.0 / sqrtf(nr); \
|
|
104
|
-
for (j = 0; j < d; j++) \
|
|
105
|
-
xi[j] *= inv_nr; \
|
|
106
|
-
}
|
|
107
|
-
|
|
108
238
|
void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
239
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
240
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
241
|
+
float* __restrict xi = x + i * d;
|
|
242
|
+
float nr = fvec_norm_L2sqr<SL>(xi, d);
|
|
243
|
+
if (nr > 0) {
|
|
244
|
+
const float inv_nr = 1.0 / sqrtf(nr);
|
|
245
|
+
for (size_t j = 0; j < d; j++) {
|
|
246
|
+
xi[j] *= inv_nr;
|
|
247
|
+
}
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
});
|
|
112
251
|
}
|
|
113
252
|
|
|
114
253
|
void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
|
|
254
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
115
255
|
#pragma omp parallel for if (nx > 10000)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
256
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
257
|
+
float* __restrict xi = x + i * d;
|
|
258
|
+
float nr = fvec_norm_L2sqr<SL>(xi, d);
|
|
259
|
+
if (nr > 0) {
|
|
260
|
+
const float inv_nr = 1.0 / sqrtf(nr);
|
|
261
|
+
for (size_t j = 0; j < d; j++) {
|
|
262
|
+
xi[j] *= inv_nr;
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
});
|
|
119
267
|
}
|
|
120
268
|
|
|
121
269
|
void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
@@ -148,22 +296,24 @@ void exhaustive_inner_product_seq(
|
|
|
148
296
|
#pragma omp parallel num_threads(nt)
|
|
149
297
|
{
|
|
150
298
|
SingleResultHandler resi(res);
|
|
299
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
151
300
|
#pragma omp for
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
301
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
302
|
+
const float* x_i = x + i * d;
|
|
303
|
+
const float* y_j = y;
|
|
155
304
|
|
|
156
|
-
|
|
305
|
+
resi.begin(i);
|
|
157
306
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
307
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
308
|
+
if (!res.is_in_selection(j)) {
|
|
309
|
+
continue;
|
|
310
|
+
}
|
|
311
|
+
float ip = fvec_inner_product<SL>(x_i, y_j, d);
|
|
312
|
+
resi.add_result(ip, j);
|
|
161
313
|
}
|
|
162
|
-
|
|
163
|
-
resi.add_result(ip, j);
|
|
314
|
+
resi.end();
|
|
164
315
|
}
|
|
165
|
-
|
|
166
|
-
}
|
|
316
|
+
});
|
|
167
317
|
}
|
|
168
318
|
}
|
|
169
319
|
|
|
@@ -182,20 +332,22 @@ void exhaustive_L2sqr_seq(
|
|
|
182
332
|
#pragma omp parallel num_threads(nt)
|
|
183
333
|
{
|
|
184
334
|
SingleResultHandler resi(res);
|
|
335
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
185
336
|
#pragma omp for
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
337
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
338
|
+
const float* x_i = x + i * d;
|
|
339
|
+
const float* y_j = y;
|
|
340
|
+
resi.begin(i);
|
|
341
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
342
|
+
if (!res.is_in_selection(j)) {
|
|
343
|
+
continue;
|
|
344
|
+
}
|
|
345
|
+
float disij = fvec_L2sqr<SL>(x_i, y_j, d);
|
|
346
|
+
resi.add_result(disij, j);
|
|
193
347
|
}
|
|
194
|
-
|
|
195
|
-
resi.add_result(disij, j);
|
|
348
|
+
resi.end();
|
|
196
349
|
}
|
|
197
|
-
|
|
198
|
-
}
|
|
350
|
+
});
|
|
199
351
|
}
|
|
200
352
|
}
|
|
201
353
|
|
|
@@ -321,7 +473,7 @@ void exhaustive_L2sqr_blas_default_impl(
|
|
|
321
473
|
ip_block.get(),
|
|
322
474
|
&nyi);
|
|
323
475
|
}
|
|
324
|
-
for (
|
|
476
|
+
for (size_t i = i0; i < i1; i++) {
|
|
325
477
|
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
326
478
|
|
|
327
479
|
for (size_t j = j0; j < j1; j++) {
|
|
@@ -357,396 +509,12 @@ void exhaustive_L2sqr_blas(
|
|
|
357
509
|
size_t ny,
|
|
358
510
|
BlockResultHandler& res,
|
|
359
511
|
const float* y_norms = nullptr) {
|
|
360
|
-
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
|
|
361
|
-
}
|
|
362
|
-
|
|
363
|
-
#ifdef __AVX2__
|
|
364
|
-
void exhaustive_L2sqr_blas_cmax_avx2(
|
|
365
|
-
const float* x,
|
|
366
|
-
const float* y,
|
|
367
|
-
size_t d,
|
|
368
|
-
size_t nx,
|
|
369
|
-
size_t ny,
|
|
370
|
-
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
371
|
-
const float* y_norms) {
|
|
372
|
-
// BLAS does not like empty matrices
|
|
373
|
-
if (nx == 0 || ny == 0) {
|
|
374
|
-
return;
|
|
375
|
-
}
|
|
376
|
-
|
|
377
|
-
/* block sizes */
|
|
378
|
-
const size_t bs_x = distance_compute_blas_query_bs;
|
|
379
|
-
const size_t bs_y = distance_compute_blas_database_bs;
|
|
380
|
-
// const size_t bs_x = 16, bs_y = 16;
|
|
381
|
-
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
382
|
-
std::unique_ptr<float[]> x_norms(new float[nx]);
|
|
383
|
-
std::unique_ptr<float[]> del2;
|
|
384
|
-
|
|
385
|
-
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
|
386
|
-
|
|
387
|
-
if (!y_norms) {
|
|
388
|
-
float* y_norms2 = new float[ny];
|
|
389
|
-
del2.reset(y_norms2);
|
|
390
|
-
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
|
391
|
-
y_norms = y_norms2;
|
|
392
|
-
}
|
|
393
|
-
|
|
394
|
-
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
395
|
-
size_t i1 = i0 + bs_x;
|
|
396
|
-
if (i1 > nx) {
|
|
397
|
-
i1 = nx;
|
|
398
|
-
}
|
|
399
|
-
|
|
400
|
-
res.begin_multiple(i0, i1);
|
|
401
|
-
|
|
402
|
-
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
403
|
-
size_t j1 = j0 + bs_y;
|
|
404
|
-
if (j1 > ny) {
|
|
405
|
-
j1 = ny;
|
|
406
|
-
}
|
|
407
|
-
/* compute the actual dot products */
|
|
408
|
-
{
|
|
409
|
-
float one = 1, zero = 0;
|
|
410
|
-
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
411
|
-
sgemm_("Transpose",
|
|
412
|
-
"Not transpose",
|
|
413
|
-
&nyi,
|
|
414
|
-
&nxi,
|
|
415
|
-
&di,
|
|
416
|
-
&one,
|
|
417
|
-
y + j0 * d,
|
|
418
|
-
&di,
|
|
419
|
-
x + i0 * d,
|
|
420
|
-
&di,
|
|
421
|
-
&zero,
|
|
422
|
-
ip_block.get(),
|
|
423
|
-
&nyi);
|
|
424
|
-
}
|
|
425
|
-
for (int64_t i = i0; i < i1; i++) {
|
|
426
|
-
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
427
|
-
|
|
428
|
-
_mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
|
|
429
|
-
_mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
|
|
430
|
-
|
|
431
|
-
// constant
|
|
432
|
-
const __m256 mul_minus2 = _mm256_set1_ps(-2);
|
|
433
|
-
|
|
434
|
-
// Track 8 min distances + 8 min indices.
|
|
435
|
-
// All the distances tracked do not take x_norms[i]
|
|
436
|
-
// into account in order to get rid of extra
|
|
437
|
-
// _mm256_add_ps(x_norms[i], ...) instructions
|
|
438
|
-
// is distance computations.
|
|
439
|
-
__m256 min_distances =
|
|
440
|
-
_mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
|
|
441
|
-
|
|
442
|
-
// these indices are local and are relative to j0.
|
|
443
|
-
// so, value 0 means j0.
|
|
444
|
-
__m256i min_indices = _mm256_set1_epi32(0);
|
|
445
|
-
|
|
446
|
-
__m256i current_indices =
|
|
447
|
-
_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
448
|
-
const __m256i indices_delta = _mm256_set1_epi32(8);
|
|
449
|
-
|
|
450
|
-
// current j index
|
|
451
|
-
size_t idx_j = 0;
|
|
452
|
-
size_t count = j1 - j0;
|
|
453
|
-
|
|
454
|
-
// process 16 elements per loop
|
|
455
|
-
for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
|
|
456
|
-
_mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
|
|
457
|
-
_mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
|
|
458
|
-
|
|
459
|
-
// load values for norms
|
|
460
|
-
const __m256 y_norm_0 =
|
|
461
|
-
_mm256_loadu_ps(y_norms + idx_j + j0 + 0);
|
|
462
|
-
const __m256 y_norm_1 =
|
|
463
|
-
_mm256_loadu_ps(y_norms + idx_j + j0 + 8);
|
|
464
|
-
|
|
465
|
-
// load values for dot products
|
|
466
|
-
const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
|
|
467
|
-
const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
|
|
468
|
-
|
|
469
|
-
// compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
|
|
470
|
-
// x_norm[i] was dropped off because it is a constant for a
|
|
471
|
-
// given i. We'll deal with it later.
|
|
472
|
-
__m256 distances_0 =
|
|
473
|
-
_mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
|
|
474
|
-
__m256 distances_1 =
|
|
475
|
-
_mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
|
|
476
|
-
|
|
477
|
-
// compare the new distances to the min distances
|
|
478
|
-
// for each of the first group of 8 AVX2 components.
|
|
479
|
-
const __m256 comparison_0 = _mm256_cmp_ps(
|
|
480
|
-
min_distances, distances_0, _CMP_LE_OS);
|
|
481
|
-
|
|
482
|
-
// update min distances and indices with closest vectors if
|
|
483
|
-
// needed.
|
|
484
|
-
min_distances = _mm256_blendv_ps(
|
|
485
|
-
distances_0, min_distances, comparison_0);
|
|
486
|
-
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
487
|
-
_mm256_castsi256_ps(current_indices),
|
|
488
|
-
_mm256_castsi256_ps(min_indices),
|
|
489
|
-
comparison_0));
|
|
490
|
-
current_indices =
|
|
491
|
-
_mm256_add_epi32(current_indices, indices_delta);
|
|
492
|
-
|
|
493
|
-
// compare the new distances to the min distances
|
|
494
|
-
// for each of the second group of 8 AVX2 components.
|
|
495
|
-
const __m256 comparison_1 = _mm256_cmp_ps(
|
|
496
|
-
min_distances, distances_1, _CMP_LE_OS);
|
|
497
|
-
|
|
498
|
-
// update min distances and indices with closest vectors if
|
|
499
|
-
// needed.
|
|
500
|
-
min_distances = _mm256_blendv_ps(
|
|
501
|
-
distances_1, min_distances, comparison_1);
|
|
502
|
-
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
503
|
-
_mm256_castsi256_ps(current_indices),
|
|
504
|
-
_mm256_castsi256_ps(min_indices),
|
|
505
|
-
comparison_1));
|
|
506
|
-
current_indices =
|
|
507
|
-
_mm256_add_epi32(current_indices, indices_delta);
|
|
508
|
-
}
|
|
509
|
-
|
|
510
|
-
// dump values and find the minimum distance / minimum index
|
|
511
|
-
float min_distances_scalar[8];
|
|
512
|
-
uint32_t min_indices_scalar[8];
|
|
513
|
-
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
514
|
-
_mm256_storeu_si256(
|
|
515
|
-
(__m256i*)(min_indices_scalar), min_indices);
|
|
516
|
-
|
|
517
|
-
float current_min_distance = res.dis_tab[i];
|
|
518
|
-
uint32_t current_min_index = res.ids_tab[i];
|
|
519
|
-
|
|
520
|
-
// This unusual comparison is needed to maintain the behavior
|
|
521
|
-
// of the original implementation: if two indices are
|
|
522
|
-
// represented with equal distance values, then
|
|
523
|
-
// the index with the min value is returned.
|
|
524
|
-
for (size_t jv = 0; jv < 8; jv++) {
|
|
525
|
-
// add missing x_norms[i]
|
|
526
|
-
float distance_candidate =
|
|
527
|
-
min_distances_scalar[jv] + x_norms[i];
|
|
528
|
-
|
|
529
|
-
// negative values can occur for identical vectors
|
|
530
|
-
// due to roundoff errors.
|
|
531
|
-
if (distance_candidate < 0) {
|
|
532
|
-
distance_candidate = 0;
|
|
533
|
-
}
|
|
534
|
-
|
|
535
|
-
int64_t index_candidate = min_indices_scalar[jv] + j0;
|
|
536
|
-
|
|
537
|
-
if (current_min_distance > distance_candidate) {
|
|
538
|
-
current_min_distance = distance_candidate;
|
|
539
|
-
current_min_index = index_candidate;
|
|
540
|
-
} else if (
|
|
541
|
-
current_min_distance == distance_candidate &&
|
|
542
|
-
current_min_index > index_candidate) {
|
|
543
|
-
current_min_index = index_candidate;
|
|
544
|
-
}
|
|
545
|
-
}
|
|
546
|
-
|
|
547
|
-
// process leftovers
|
|
548
|
-
for (; idx_j < count; idx_j++, ip_line++) {
|
|
549
|
-
float ip = *ip_line;
|
|
550
|
-
float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
|
|
551
|
-
// negative values can occur for identical vectors
|
|
552
|
-
// due to roundoff errors.
|
|
553
|
-
if (dis < 0) {
|
|
554
|
-
dis = 0;
|
|
555
|
-
}
|
|
556
|
-
|
|
557
|
-
if (current_min_distance > dis) {
|
|
558
|
-
current_min_distance = dis;
|
|
559
|
-
current_min_index = idx_j + j0;
|
|
560
|
-
}
|
|
561
|
-
}
|
|
562
|
-
|
|
563
|
-
//
|
|
564
|
-
res.add_result(i, current_min_distance, current_min_index);
|
|
565
|
-
}
|
|
566
|
-
}
|
|
567
|
-
// Does nothing for SingleBestResultHandler, but
|
|
568
|
-
// keeping the call for the consistency.
|
|
569
|
-
res.end_multiple();
|
|
570
|
-
InterruptCallback::check();
|
|
571
|
-
}
|
|
512
|
+
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res, y_norms);
|
|
572
513
|
}
|
|
573
|
-
#elif defined(__ARM_FEATURE_SVE)
|
|
574
|
-
void exhaustive_L2sqr_blas_cmax_sve(
|
|
575
|
-
const float* x,
|
|
576
|
-
const float* y,
|
|
577
|
-
size_t d,
|
|
578
|
-
size_t nx,
|
|
579
|
-
size_t ny,
|
|
580
|
-
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
581
|
-
const float* y_norms) {
|
|
582
|
-
// BLAS does not like empty matrices
|
|
583
|
-
if (nx == 0 || ny == 0)
|
|
584
|
-
return;
|
|
585
514
|
|
|
586
|
-
|
|
587
|
-
const size_t bs_x = distance_compute_blas_query_bs;
|
|
588
|
-
const size_t bs_y = distance_compute_blas_database_bs;
|
|
589
|
-
// const size_t bs_x = 16, bs_y = 16;
|
|
590
|
-
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
591
|
-
std::unique_ptr<float[]> x_norms(new float[nx]);
|
|
592
|
-
std::unique_ptr<float[]> del2;
|
|
593
|
-
|
|
594
|
-
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
|
595
|
-
|
|
596
|
-
const size_t lanes = svcntw();
|
|
597
|
-
|
|
598
|
-
if (!y_norms) {
|
|
599
|
-
float* y_norms2 = new float[ny];
|
|
600
|
-
del2.reset(y_norms2);
|
|
601
|
-
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
|
602
|
-
y_norms = y_norms2;
|
|
603
|
-
}
|
|
604
|
-
|
|
605
|
-
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
606
|
-
size_t i1 = i0 + bs_x;
|
|
607
|
-
if (i1 > nx)
|
|
608
|
-
i1 = nx;
|
|
609
|
-
|
|
610
|
-
res.begin_multiple(i0, i1);
|
|
611
|
-
|
|
612
|
-
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
613
|
-
size_t j1 = j0 + bs_y;
|
|
614
|
-
if (j1 > ny)
|
|
615
|
-
j1 = ny;
|
|
616
|
-
/* compute the actual dot products */
|
|
617
|
-
{
|
|
618
|
-
float one = 1, zero = 0;
|
|
619
|
-
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
620
|
-
sgemm_("Transpose",
|
|
621
|
-
"Not transpose",
|
|
622
|
-
&nyi,
|
|
623
|
-
&nxi,
|
|
624
|
-
&di,
|
|
625
|
-
&one,
|
|
626
|
-
y + j0 * d,
|
|
627
|
-
&di,
|
|
628
|
-
x + i0 * d,
|
|
629
|
-
&di,
|
|
630
|
-
&zero,
|
|
631
|
-
ip_block.get(),
|
|
632
|
-
&nyi);
|
|
633
|
-
}
|
|
634
|
-
for (int64_t i = i0; i < i1; i++) {
|
|
635
|
-
const size_t count = j1 - j0;
|
|
636
|
-
float* ip_line = ip_block.get() + (i - i0) * count;
|
|
637
|
-
|
|
638
|
-
svprfw(svwhilelt_b32_u64(0, count), ip_line, SV_PLDL1KEEP);
|
|
639
|
-
svprfw(svwhilelt_b32_u64(lanes, count),
|
|
640
|
-
ip_line + lanes,
|
|
641
|
-
SV_PLDL1KEEP);
|
|
642
|
-
|
|
643
|
-
// Track lanes min distances + lanes min indices.
|
|
644
|
-
// All the distances tracked do not take x_norms[i]
|
|
645
|
-
// into account in order to get rid of extra
|
|
646
|
-
// vaddq_f32(x_norms[i], ...) instructions
|
|
647
|
-
// is distance computations.
|
|
648
|
-
auto min_distances = svdup_n_f32(res.dis_tab[i] - x_norms[i]);
|
|
649
|
-
|
|
650
|
-
// these indices are local and are relative to j0.
|
|
651
|
-
// so, value 0 means j0.
|
|
652
|
-
auto min_indices = svdup_n_u32(0u);
|
|
653
|
-
|
|
654
|
-
auto current_indices = svindex_u32(0u, 1u);
|
|
655
|
-
|
|
656
|
-
// process lanes * 2 elements per loop
|
|
657
|
-
for (size_t idx_j = 0; idx_j < count;
|
|
658
|
-
idx_j += lanes * 2, ip_line += lanes * 2) {
|
|
659
|
-
svprfw(svwhilelt_b32_u64(idx_j + lanes * 2, count),
|
|
660
|
-
ip_line + lanes * 2,
|
|
661
|
-
SV_PLDL1KEEP);
|
|
662
|
-
svprfw(svwhilelt_b32_u64(idx_j + lanes * 3, count),
|
|
663
|
-
ip_line + lanes * 3,
|
|
664
|
-
SV_PLDL1KEEP);
|
|
665
|
-
|
|
666
|
-
// mask
|
|
667
|
-
const auto mask_0 = svwhilelt_b32_u64(idx_j, count);
|
|
668
|
-
const auto mask_1 = svwhilelt_b32_u64(idx_j + lanes, count);
|
|
669
|
-
|
|
670
|
-
// load values for norms
|
|
671
|
-
const auto y_norm_0 =
|
|
672
|
-
svld1_f32(mask_0, y_norms + idx_j + j0 + 0);
|
|
673
|
-
const auto y_norm_1 =
|
|
674
|
-
svld1_f32(mask_1, y_norms + idx_j + j0 + lanes);
|
|
675
|
-
|
|
676
|
-
// load values for dot products
|
|
677
|
-
const auto ip_0 = svld1_f32(mask_0, ip_line + 0);
|
|
678
|
-
const auto ip_1 = svld1_f32(mask_1, ip_line + lanes);
|
|
679
|
-
|
|
680
|
-
// compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
|
|
681
|
-
// x_norm[i] was dropped off because it is a constant for a
|
|
682
|
-
// given i. We'll deal with it later.
|
|
683
|
-
const auto distances_0 =
|
|
684
|
-
svmla_n_f32_z(mask_0, y_norm_0, ip_0, -2.f);
|
|
685
|
-
const auto distances_1 =
|
|
686
|
-
svmla_n_f32_z(mask_1, y_norm_1, ip_1, -2.f);
|
|
687
|
-
|
|
688
|
-
// compare the new distances to the min distances
|
|
689
|
-
// for each of the first group of 4 ARM SIMD components.
|
|
690
|
-
auto comparison =
|
|
691
|
-
svcmpgt_f32(mask_0, min_distances, distances_0);
|
|
692
|
-
|
|
693
|
-
// update min distances and indices with closest vectors if
|
|
694
|
-
// needed.
|
|
695
|
-
min_distances =
|
|
696
|
-
svsel_f32(comparison, distances_0, min_distances);
|
|
697
|
-
min_indices =
|
|
698
|
-
svsel_u32(comparison, current_indices, min_indices);
|
|
699
|
-
current_indices = svadd_n_u32_x(
|
|
700
|
-
mask_0,
|
|
701
|
-
current_indices,
|
|
702
|
-
static_cast<uint32_t>(lanes));
|
|
703
|
-
|
|
704
|
-
// compare the new distances to the min distances
|
|
705
|
-
// for each of the second group of 4 ARM SIMD components.
|
|
706
|
-
comparison =
|
|
707
|
-
svcmpgt_f32(mask_1, min_distances, distances_1);
|
|
708
|
-
|
|
709
|
-
// update min distances and indices with closest vectors if
|
|
710
|
-
// needed.
|
|
711
|
-
min_distances =
|
|
712
|
-
svsel_f32(comparison, distances_1, min_distances);
|
|
713
|
-
min_indices =
|
|
714
|
-
svsel_u32(comparison, current_indices, min_indices);
|
|
715
|
-
current_indices = svadd_n_u32_x(
|
|
716
|
-
mask_1,
|
|
717
|
-
current_indices,
|
|
718
|
-
static_cast<uint32_t>(lanes));
|
|
719
|
-
}
|
|
515
|
+
} // anonymous namespace
|
|
720
516
|
|
|
721
|
-
|
|
722
|
-
// negative values can occur for identical vectors
|
|
723
|
-
// due to roundoff errors.
|
|
724
|
-
auto mask = svwhilelt_b32_u64(0, count);
|
|
725
|
-
min_distances = svadd_n_f32_z(
|
|
726
|
-
svcmpge_n_f32(mask, min_distances, -x_norms[i]),
|
|
727
|
-
min_distances,
|
|
728
|
-
x_norms[i]);
|
|
729
|
-
min_indices = svadd_n_u32_x(
|
|
730
|
-
mask, min_indices, static_cast<uint32_t>(j0));
|
|
731
|
-
mask = svcmple_n_f32(mask, min_distances, res.dis_tab[i]);
|
|
732
|
-
if (svcntp_b32(svptrue_b32(), mask) == 0)
|
|
733
|
-
res.add_result(i, res.dis_tab[i], res.ids_tab[i]);
|
|
734
|
-
else {
|
|
735
|
-
const auto min_distance = svminv_f32(mask, min_distances);
|
|
736
|
-
const auto min_index = svminv_u32(
|
|
737
|
-
svcmpeq_n_f32(mask, min_distances, min_distance),
|
|
738
|
-
min_indices);
|
|
739
|
-
res.add_result(i, min_distance, min_index);
|
|
740
|
-
}
|
|
741
|
-
}
|
|
742
|
-
}
|
|
743
|
-
// Does nothing for SingleBestResultHandler, but
|
|
744
|
-
// keeping the call for the consistency.
|
|
745
|
-
res.end_multiple();
|
|
746
|
-
InterruptCallback::check();
|
|
747
|
-
}
|
|
748
|
-
}
|
|
749
|
-
#endif
|
|
517
|
+
namespace {
|
|
750
518
|
|
|
751
519
|
// an override if only a single closest point is needed
|
|
752
520
|
template <>
|
|
@@ -758,43 +526,20 @@ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
|
758
526
|
size_t ny,
|
|
759
527
|
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
760
528
|
const float* y_norms) {
|
|
761
|
-
#if defined(__AVX2__)
|
|
762
|
-
// use a faster fused kernel if available
|
|
763
|
-
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
|
|
764
|
-
// the kernel is available and it is complete, we're done.
|
|
765
|
-
return;
|
|
766
|
-
}
|
|
767
|
-
|
|
768
|
-
// run the specialized AVX2 implementation
|
|
769
|
-
exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);
|
|
770
|
-
|
|
771
|
-
#elif defined(__ARM_FEATURE_SVE)
|
|
772
|
-
// use a faster fused kernel if available
|
|
773
|
-
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
|
|
774
|
-
// the kernel is available and it is complete, we're done.
|
|
775
|
-
return;
|
|
776
|
-
}
|
|
777
|
-
|
|
778
|
-
// run the specialized SVE implementation
|
|
779
|
-
exhaustive_L2sqr_blas_cmax_sve(x, y, d, nx, ny, res, y_norms);
|
|
780
|
-
|
|
781
|
-
#elif defined(__aarch64__)
|
|
782
529
|
// use a faster fused kernel if available
|
|
783
530
|
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
|
|
784
|
-
// the kernel is available and it is complete, we're done.
|
|
785
531
|
return;
|
|
786
532
|
}
|
|
787
533
|
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
#endif
|
|
534
|
+
with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A2>([&]<SIMDLevel SL>() {
|
|
535
|
+
if constexpr (SL == SIMDLevel::AVX2 || SL == SIMDLevel::ARM_SVE) {
|
|
536
|
+
exhaustive_L2sqr_blas_cmax<SL>(x, y, d, nx, ny, res, y_norms);
|
|
537
|
+
} else {
|
|
538
|
+
exhaustive_L2sqr_blas_default_impl<
|
|
539
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
540
|
+
x, y, d, nx, ny, res, y_norms);
|
|
541
|
+
}
|
|
542
|
+
});
|
|
798
543
|
}
|
|
799
544
|
|
|
800
545
|
struct Run_search_inner_product {
|
|
@@ -806,7 +551,8 @@ struct Run_search_inner_product {
|
|
|
806
551
|
size_t d,
|
|
807
552
|
size_t nx,
|
|
808
553
|
size_t ny) {
|
|
809
|
-
if (res.sel ||
|
|
554
|
+
if (res.sel ||
|
|
555
|
+
nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
|
|
810
556
|
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
811
557
|
} else {
|
|
812
558
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
@@ -824,7 +570,8 @@ struct Run_search_L2sqr {
|
|
|
824
570
|
size_t nx,
|
|
825
571
|
size_t ny,
|
|
826
572
|
const float* y_norm2) {
|
|
827
|
-
if (res.sel ||
|
|
573
|
+
if (res.sel ||
|
|
574
|
+
nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
|
|
828
575
|
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
829
576
|
} else {
|
|
830
577
|
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
@@ -838,11 +585,174 @@ struct Run_search_L2sqr {
|
|
|
838
585
|
* KNN driver functions
|
|
839
586
|
*******************************************************/
|
|
840
587
|
|
|
841
|
-
int distance_compute_blas_threshold =
|
|
588
|
+
int distance_compute_blas_threshold = 128000;
|
|
842
589
|
int distance_compute_blas_query_bs = 4096;
|
|
843
590
|
int distance_compute_blas_database_bs = 1024;
|
|
844
591
|
int distance_compute_min_k_reservoir = 100;
|
|
845
592
|
|
|
593
|
+
// Database-parallel KNN: parallelizes over database segments instead of
|
|
594
|
+
// queries, for the case where nx < nthreads and the database is large.
|
|
595
|
+
static constexpr size_t kDbParallelMinVectors = 10000;
|
|
596
|
+
|
|
597
|
+
template <class C>
|
|
598
|
+
static void knn_db_parallel_impl(
|
|
599
|
+
const float* x,
|
|
600
|
+
const float* y,
|
|
601
|
+
size_t d,
|
|
602
|
+
size_t nx,
|
|
603
|
+
size_t ny,
|
|
604
|
+
size_t k,
|
|
605
|
+
float* vals,
|
|
606
|
+
int64_t* ids,
|
|
607
|
+
const float* y_norms) {
|
|
608
|
+
using T = typename C::T;
|
|
609
|
+
using TI = typename C::TI;
|
|
610
|
+
|
|
611
|
+
int nt = omp_get_max_threads();
|
|
612
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
|
613
|
+
|
|
614
|
+
// Per-thread result heaps: nt threads x nx queries x k results
|
|
615
|
+
std::vector<T> all_dis(static_cast<size_t>(nt) * nx * k);
|
|
616
|
+
std::vector<TI> all_ids(static_cast<size_t>(nt) * nx * k);
|
|
617
|
+
|
|
618
|
+
std::unique_ptr<float[]> x_norms_storage;
|
|
619
|
+
std::unique_ptr<float[]> y_norms_storage;
|
|
620
|
+
const float* x_norms = nullptr;
|
|
621
|
+
// C::is_max corresponds to L2 (CMax), not IP (CMin)
|
|
622
|
+
if constexpr (C::is_max) {
|
|
623
|
+
x_norms_storage.reset(new float[nx]);
|
|
624
|
+
fvec_norms_L2sqr(x_norms_storage.get(), x, d, nx);
|
|
625
|
+
x_norms = x_norms_storage.get();
|
|
626
|
+
|
|
627
|
+
if (!y_norms) {
|
|
628
|
+
y_norms_storage.reset(new float[ny]);
|
|
629
|
+
y_norms = y_norms_storage.get();
|
|
630
|
+
}
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
#pragma omp parallel num_threads(nt)
|
|
634
|
+
{
|
|
635
|
+
int tid = omp_get_thread_num();
|
|
636
|
+
size_t j_begin = static_cast<size_t>(tid) * ny / nt;
|
|
637
|
+
size_t j_end = static_cast<size_t>(tid + 1) * ny / nt;
|
|
638
|
+
size_t local_ny = j_end - j_begin;
|
|
639
|
+
|
|
640
|
+
// Compute y_norms for this thread's segment (cache locality)
|
|
641
|
+
if constexpr (C::is_max) {
|
|
642
|
+
if (y_norms_storage && local_ny > 0) {
|
|
643
|
+
fvec_norms_L2sqr(
|
|
644
|
+
y_norms_storage.get() + j_begin,
|
|
645
|
+
y + j_begin * d,
|
|
646
|
+
d,
|
|
647
|
+
local_ny);
|
|
648
|
+
}
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
T* my_dis = all_dis.data() + tid * nx * k;
|
|
652
|
+
TI* my_ids = all_ids.data() + tid * nx * k;
|
|
653
|
+
|
|
654
|
+
// Each thread initializes its own heaps
|
|
655
|
+
for (size_t i = 0; i < nx; i++) {
|
|
656
|
+
heap_heapify<C>(k, my_dis + i * k, my_ids + i * k);
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
if (local_ny > 0) {
|
|
660
|
+
size_t max_block = std::min(bs_y, local_ny);
|
|
661
|
+
std::unique_ptr<float[]> ip_block(new float[nx * max_block]);
|
|
662
|
+
|
|
663
|
+
for (size_t jj0 = 0; jj0 < local_ny; jj0 += bs_y) {
|
|
664
|
+
size_t jj1 = std::min(jj0 + bs_y, local_ny);
|
|
665
|
+
size_t block_ny = jj1 - jj0;
|
|
666
|
+
|
|
667
|
+
{
|
|
668
|
+
float one = 1, zero = 0;
|
|
669
|
+
FINTEGER nyi = static_cast<FINTEGER>(block_ny);
|
|
670
|
+
FINTEGER nxi = static_cast<FINTEGER>(nx);
|
|
671
|
+
FINTEGER di = static_cast<FINTEGER>(d);
|
|
672
|
+
sgemm_("Transpose",
|
|
673
|
+
"Not transpose",
|
|
674
|
+
&nyi,
|
|
675
|
+
&nxi,
|
|
676
|
+
&di,
|
|
677
|
+
&one,
|
|
678
|
+
y + (j_begin + jj0) * d,
|
|
679
|
+
&di,
|
|
680
|
+
x,
|
|
681
|
+
&di,
|
|
682
|
+
&zero,
|
|
683
|
+
ip_block.get(),
|
|
684
|
+
&nyi);
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
for (size_t i = 0; i < nx; i++) {
|
|
688
|
+
T* heap_dis = my_dis + i * k;
|
|
689
|
+
TI* heap_ids = my_ids + i * k;
|
|
690
|
+
const float* ip_line = ip_block.get() + i * block_ny;
|
|
691
|
+
T thresh = heap_dis[0];
|
|
692
|
+
|
|
693
|
+
for (size_t jj = 0; jj < block_ny; jj++) {
|
|
694
|
+
size_t global_j = j_begin + jj0 + jj;
|
|
695
|
+
float ip = ip_line[jj];
|
|
696
|
+
T dis;
|
|
697
|
+
|
|
698
|
+
if constexpr (C::is_max) {
|
|
699
|
+
dis = x_norms[i] + y_norms[global_j] - 2 * ip;
|
|
700
|
+
if (dis < 0) {
|
|
701
|
+
dis = 0;
|
|
702
|
+
}
|
|
703
|
+
} else {
|
|
704
|
+
dis = ip;
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
if (C::cmp(thresh, dis)) {
|
|
708
|
+
heap_replace_top<C>(
|
|
709
|
+
k, heap_dis, heap_ids, dis, global_j);
|
|
710
|
+
thresh = heap_dis[0];
|
|
711
|
+
}
|
|
712
|
+
}
|
|
713
|
+
}
|
|
714
|
+
}
|
|
715
|
+
}
|
|
716
|
+
}
|
|
717
|
+
|
|
718
|
+
// Merge per-thread heaps into output, parallelized over queries
|
|
719
|
+
#pragma omp parallel for
|
|
720
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
721
|
+
heap_heapify<C>(k, vals + i * k, ids + i * k);
|
|
722
|
+
|
|
723
|
+
for (int t = 0; t < nt; t++) {
|
|
724
|
+
T* t_dis = all_dis.data() + (t * nx + i) * k;
|
|
725
|
+
TI* t_ids = all_ids.data() + (t * nx + i) * k;
|
|
726
|
+
T* out_dis = vals + i * k;
|
|
727
|
+
TI* out_ids = ids + i * k;
|
|
728
|
+
|
|
729
|
+
for (size_t j = 0; j < k; j++) {
|
|
730
|
+
if (t_ids[j] >= 0 && C::cmp(out_dis[0], t_dis[j])) {
|
|
731
|
+
heap_replace_top<C>(
|
|
732
|
+
k, out_dis, out_ids, t_dis[j], t_ids[j]);
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
heap_reorder<C>(k, vals + i * k, ids + i * k);
|
|
738
|
+
}
|
|
739
|
+
}
|
|
740
|
+
|
|
741
|
+
static bool should_use_db_parallel(
|
|
742
|
+
size_t nx,
|
|
743
|
+
size_t ny,
|
|
744
|
+
const IDSelector* sel) {
|
|
745
|
+
if (sel) {
|
|
746
|
+
return false;
|
|
747
|
+
}
|
|
748
|
+
int nt = omp_get_max_threads();
|
|
749
|
+
size_t min_ny = std::max(
|
|
750
|
+
kDbParallelMinVectors,
|
|
751
|
+
static_cast<size_t>(nt) *
|
|
752
|
+
static_cast<size_t>(distance_compute_blas_database_bs));
|
|
753
|
+
return nt > 1 && nx < static_cast<size_t>(nt) && ny >= min_ny;
|
|
754
|
+
}
|
|
755
|
+
|
|
846
756
|
void knn_inner_product(
|
|
847
757
|
const float* x,
|
|
848
758
|
const float* y,
|
|
@@ -867,9 +777,26 @@ void knn_inner_product(
|
|
|
867
777
|
return;
|
|
868
778
|
}
|
|
869
779
|
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
780
|
+
if (should_use_db_parallel(nx, ny, sel)) {
|
|
781
|
+
knn_db_parallel_impl<CMin<float, int64_t>>(
|
|
782
|
+
x, y, d, nx, ny, k, vals, ids, nullptr);
|
|
783
|
+
} else {
|
|
784
|
+
Run_search_inner_product r;
|
|
785
|
+
// @lint-ignore CLANGTIDY facebook-hte-NullableDereference
|
|
786
|
+
dispatch_knn_ResultHandler(
|
|
787
|
+
nx,
|
|
788
|
+
vals,
|
|
789
|
+
ids,
|
|
790
|
+
k,
|
|
791
|
+
METRIC_INNER_PRODUCT,
|
|
792
|
+
sel,
|
|
793
|
+
r,
|
|
794
|
+
x,
|
|
795
|
+
y,
|
|
796
|
+
d,
|
|
797
|
+
nx,
|
|
798
|
+
ny);
|
|
799
|
+
}
|
|
873
800
|
|
|
874
801
|
if (imin != 0) {
|
|
875
802
|
for (size_t i = 0; i < nx * k; i++) {
|
|
@@ -916,9 +843,15 @@ void knn_L2sqr(
|
|
|
916
843
|
return;
|
|
917
844
|
}
|
|
918
845
|
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
846
|
+
if (should_use_db_parallel(nx, ny, sel)) {
|
|
847
|
+
knn_db_parallel_impl<CMax<float, int64_t>>(
|
|
848
|
+
x, y, d, nx, ny, k, vals, ids, y_norm2);
|
|
849
|
+
} else {
|
|
850
|
+
Run_search_L2sqr r;
|
|
851
|
+
// @lint-ignore CLANGTIDY facebook-hte-NullableDereference
|
|
852
|
+
dispatch_knn_ResultHandler(
|
|
853
|
+
nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
|
|
854
|
+
}
|
|
922
855
|
|
|
923
856
|
if (imin != 0) {
|
|
924
857
|
for (size_t i = 0; i < nx * k; i++) {
|
|
@@ -989,19 +922,21 @@ void fvec_inner_products_by_idx(
|
|
|
989
922
|
size_t d,
|
|
990
923
|
size_t nx,
|
|
991
924
|
size_t ny) {
|
|
925
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
992
926
|
#pragma omp parallel for
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
927
|
+
for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
|
|
928
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
|
929
|
+
const float* xj = x + j * d;
|
|
930
|
+
float* __restrict ipj = ip + j * ny;
|
|
931
|
+
for (size_t i = 0; i < ny; i++) {
|
|
932
|
+
if (idsj[i] < 0) {
|
|
933
|
+
ipj[i] = -INFINITY;
|
|
934
|
+
} else {
|
|
935
|
+
ipj[i] = fvec_inner_product<SL>(xj, y + d * idsj[i], d);
|
|
936
|
+
}
|
|
1002
937
|
}
|
|
1003
938
|
}
|
|
1004
|
-
}
|
|
939
|
+
});
|
|
1005
940
|
}
|
|
1006
941
|
|
|
1007
942
|
/* compute the inner product between x and a subset y of ny vectors,
|
|
@@ -1014,19 +949,21 @@ void fvec_L2sqr_by_idx(
|
|
|
1014
949
|
size_t d,
|
|
1015
950
|
size_t nx,
|
|
1016
951
|
size_t ny) {
|
|
952
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1017
953
|
#pragma omp parallel for
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
954
|
+
for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
|
|
955
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
|
956
|
+
const float* xj = x + j * d;
|
|
957
|
+
float* __restrict disj = dis + j * ny;
|
|
958
|
+
for (size_t i = 0; i < ny; i++) {
|
|
959
|
+
if (idsj[i] < 0) {
|
|
960
|
+
disj[i] = INFINITY;
|
|
961
|
+
} else {
|
|
962
|
+
disj[i] = fvec_L2sqr<SL>(xj, y + d * idsj[i], d);
|
|
963
|
+
}
|
|
1027
964
|
}
|
|
1028
965
|
}
|
|
1029
|
-
}
|
|
966
|
+
});
|
|
1030
967
|
}
|
|
1031
968
|
|
|
1032
969
|
void pairwise_indexed_L2sqr(
|
|
@@ -1037,14 +974,16 @@ void pairwise_indexed_L2sqr(
|
|
|
1037
974
|
const float* y,
|
|
1038
975
|
const int64_t* iy,
|
|
1039
976
|
float* dis) {
|
|
977
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1040
978
|
#pragma omp parallel for if (n > 1)
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
979
|
+
for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
|
|
980
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
981
|
+
dis[j] = fvec_L2sqr<SL>(x + d * ix[j], y + d * iy[j], d);
|
|
982
|
+
} else {
|
|
983
|
+
dis[j] = INFINITY;
|
|
984
|
+
}
|
|
1046
985
|
}
|
|
1047
|
-
}
|
|
986
|
+
});
|
|
1048
987
|
}
|
|
1049
988
|
|
|
1050
989
|
void pairwise_indexed_inner_product(
|
|
@@ -1055,14 +994,17 @@ void pairwise_indexed_inner_product(
|
|
|
1055
994
|
const float* y,
|
|
1056
995
|
const int64_t* iy,
|
|
1057
996
|
float* dis) {
|
|
997
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1058
998
|
#pragma omp parallel for if (n > 1)
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
999
|
+
for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
|
|
1000
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
1001
|
+
dis[j] =
|
|
1002
|
+
fvec_inner_product<SL>(x + d * ix[j], y + d * iy[j], d);
|
|
1003
|
+
} else {
|
|
1004
|
+
dis[j] = -INFINITY;
|
|
1005
|
+
}
|
|
1064
1006
|
}
|
|
1065
|
-
}
|
|
1007
|
+
});
|
|
1066
1008
|
}
|
|
1067
1009
|
|
|
1068
1010
|
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
|
@@ -1083,27 +1025,29 @@ void knn_inner_products_by_idx(
|
|
|
1083
1025
|
ld_ids = ny;
|
|
1084
1026
|
}
|
|
1085
1027
|
|
|
1028
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1086
1029
|
#pragma omp parallel for if (nx > 100)
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1030
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
1031
|
+
const float* x_ = x + i * d;
|
|
1032
|
+
const int64_t* idsi = ids + i * ld_ids;
|
|
1033
|
+
size_t j;
|
|
1034
|
+
float* __restrict simi = res_vals + i * k;
|
|
1035
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
1036
|
+
minheap_heapify(k, simi, idxi);
|
|
1037
|
+
|
|
1038
|
+
for (j = 0; j < nsubset; j++) {
|
|
1039
|
+
if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
|
|
1040
|
+
break;
|
|
1041
|
+
}
|
|
1042
|
+
float ip = fvec_inner_product<SL>(x_, y + d * idsi[j], d);
|
|
1100
1043
|
|
|
1101
|
-
|
|
1102
|
-
|
|
1044
|
+
if (ip > simi[0]) {
|
|
1045
|
+
minheap_replace_top(k, simi, idxi, ip, idsi[j]);
|
|
1046
|
+
}
|
|
1103
1047
|
}
|
|
1048
|
+
minheap_reorder(k, simi, idxi);
|
|
1104
1049
|
}
|
|
1105
|
-
|
|
1106
|
-
}
|
|
1050
|
+
});
|
|
1107
1051
|
}
|
|
1108
1052
|
|
|
1109
1053
|
void knn_L2sqr_by_idx(
|
|
@@ -1121,25 +1065,27 @@ void knn_L2sqr_by_idx(
|
|
|
1121
1065
|
if (ld_ids < 0) {
|
|
1122
1066
|
ld_ids = ny;
|
|
1123
1067
|
}
|
|
1068
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1124
1069
|
#pragma omp parallel for if (nx > 100)
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1070
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
1071
|
+
const float* x_ = x + i * d;
|
|
1072
|
+
const int64_t* __restrict idsi = ids + i * ld_ids;
|
|
1073
|
+
float* __restrict simi = res_vals + i * k;
|
|
1074
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
1075
|
+
maxheap_heapify(k, simi, idxi);
|
|
1076
|
+
for (size_t j = 0; j < nsubset; j++) {
|
|
1077
|
+
if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
|
|
1078
|
+
break;
|
|
1079
|
+
}
|
|
1080
|
+
float disij = fvec_L2sqr<SL>(x_, y + d * idsi[j], d);
|
|
1136
1081
|
|
|
1137
|
-
|
|
1138
|
-
|
|
1082
|
+
if (disij < simi[0]) {
|
|
1083
|
+
maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
|
|
1084
|
+
}
|
|
1139
1085
|
}
|
|
1086
|
+
maxheap_reorder(k, simi, idxi);
|
|
1140
1087
|
}
|
|
1141
|
-
|
|
1142
|
-
}
|
|
1088
|
+
});
|
|
1143
1089
|
}
|
|
1144
1090
|
|
|
1145
1091
|
void pairwise_L2sqr(
|
|
@@ -1168,25 +1114,27 @@ void pairwise_L2sqr(
|
|
|
1168
1114
|
// store in beginning of distance matrix to avoid malloc
|
|
1169
1115
|
float* b_norms = dis;
|
|
1170
1116
|
|
|
1117
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1171
1118
|
#pragma omp parallel for if (nb > 1)
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1119
|
+
for (int64_t i = 0; i < nb; i++) {
|
|
1120
|
+
b_norms[i] = fvec_norm_L2sqr<SL>(xb + i * ldb, d);
|
|
1121
|
+
}
|
|
1175
1122
|
|
|
1176
1123
|
#pragma omp parallel for
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1124
|
+
for (int64_t i = 1; i < nq; i++) {
|
|
1125
|
+
float q_norm = fvec_norm_L2sqr<SL>(xq + i * ldq, d);
|
|
1126
|
+
for (int64_t j = 0; j < nb; j++) {
|
|
1127
|
+
dis[i * ldd + j] = q_norm + b_norms[j];
|
|
1128
|
+
}
|
|
1181
1129
|
}
|
|
1182
|
-
}
|
|
1183
1130
|
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1131
|
+
{
|
|
1132
|
+
float q_norm = fvec_norm_L2sqr<SL>(xq, d);
|
|
1133
|
+
for (int64_t j = 0; j < nb; j++) {
|
|
1134
|
+
dis[j] += q_norm;
|
|
1135
|
+
}
|
|
1188
1136
|
}
|
|
1189
|
-
}
|
|
1137
|
+
});
|
|
1190
1138
|
|
|
1191
1139
|
{
|
|
1192
1140
|
FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
|
|
@@ -1215,7 +1163,7 @@ void inner_product_to_L2sqr(
|
|
|
1215
1163
|
size_t n1,
|
|
1216
1164
|
size_t n2) {
|
|
1217
1165
|
#pragma omp parallel for
|
|
1218
|
-
for (int64_t j = 0; j < n1; j++) {
|
|
1166
|
+
for (int64_t j = 0; j < static_cast<int64_t>(n1); j++) {
|
|
1219
1167
|
float* disj = dis + j * n2;
|
|
1220
1168
|
for (size_t i = 0; i < n2; i++) {
|
|
1221
1169
|
disj[i] = nr1[j] + nr2[i] - 2 * disj[i];
|