faiss 0.5.3 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +4 -4
- data/ext/faiss/index.cpp +63 -45
- data/ext/faiss/index_binary.cpp +37 -27
- data/ext/faiss/kmeans.cpp +9 -8
- data/ext/faiss/pca_matrix.cpp +9 -7
- data/ext/faiss/product_quantizer.cpp +13 -11
- data/ext/faiss/utils.cpp +4 -2
- data/ext/faiss/utils.h +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +214 -82
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +97 -249
- data/vendor/faiss/faiss/Clustering.h +18 -0
- data/vendor/faiss/faiss/IVFlib.cpp +67 -44
- data/vendor/faiss/faiss/Index.cpp +25 -12
- data/vendor/faiss/faiss/Index.h +26 -4
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
- data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
- data/vendor/faiss/faiss/IndexFastScan.h +35 -24
- data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
- data/vendor/faiss/faiss/IndexFlat.h +32 -14
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
- data/vendor/faiss/faiss/IndexHNSW.h +30 -14
- data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
- data/vendor/faiss/faiss/IndexIVF.h +47 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
- data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
- data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
- data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
- data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
- data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +13 -13
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +29 -6
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
- data/vendor/faiss/faiss/VectorTransform.h +39 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +55 -51
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
- data/vendor/faiss/faiss/impl/HNSW.h +21 -40
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
- data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
- data/vendor/faiss/faiss/impl/NSG.h +20 -10
- data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
- data/vendor/faiss/faiss/impl/Panorama.h +265 -78
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
- data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +115 -28
- data/vendor/faiss/faiss/index_io.h +53 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +507 -559
- data/vendor/faiss/faiss/utils/distances.h +118 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
- data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
- data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +21 -14
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +156 -42
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -150,7 +150,9 @@ void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
|
|
|
150
150
|
}
|
|
151
151
|
|
|
152
152
|
void VectorTransform::check_identical(const VectorTransform& other) const {
|
|
153
|
-
|
|
153
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
154
|
+
other.d_in == d_in && other.d_out == d_out,
|
|
155
|
+
"transforms must have matching d_in and d_out");
|
|
154
156
|
}
|
|
155
157
|
|
|
156
158
|
/*********************************************
|
|
@@ -158,9 +160,9 @@ void VectorTransform::check_identical(const VectorTransform& other) const {
|
|
|
158
160
|
*********************************************/
|
|
159
161
|
|
|
160
162
|
/// both d_in > d_out and d_out < d_in are supported
|
|
161
|
-
LinearTransform::LinearTransform(int
|
|
162
|
-
: VectorTransform(
|
|
163
|
-
have_bias(
|
|
163
|
+
LinearTransform::LinearTransform(int din, int dout, bool have_bias_in)
|
|
164
|
+
: VectorTransform(din, dout),
|
|
165
|
+
have_bias(have_bias_in),
|
|
164
166
|
is_orthonormal(false),
|
|
165
167
|
verbose(false) {
|
|
166
168
|
is_trained = false; // will be trained when A and b are initialized
|
|
@@ -171,21 +173,25 @@ void LinearTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
171
173
|
|
|
172
174
|
float c_factor;
|
|
173
175
|
if (have_bias) {
|
|
174
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
176
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
177
|
+
b.size() == static_cast<size_t>(d_out), "Bias not initialized");
|
|
175
178
|
float* xi = xt;
|
|
176
|
-
for (
|
|
177
|
-
for (int j = 0; j < d_out; j++)
|
|
179
|
+
for (idx_t i = 0; i < n; i++) {
|
|
180
|
+
for (int j = 0; j < d_out; j++) {
|
|
178
181
|
*xi++ = b[j];
|
|
182
|
+
}
|
|
183
|
+
}
|
|
179
184
|
c_factor = 1.0;
|
|
180
185
|
} else {
|
|
181
186
|
c_factor = 0.0;
|
|
182
187
|
}
|
|
183
188
|
|
|
184
189
|
FAISS_THROW_IF_NOT_MSG(
|
|
185
|
-
A.size() == d_out * d_in,
|
|
190
|
+
A.size() == static_cast<size_t>(d_out) * d_in,
|
|
191
|
+
"Transformation matrix not initialized");
|
|
186
192
|
|
|
187
193
|
float one = 1;
|
|
188
|
-
FINTEGER nbiti = d_out, ni = n, di = d_in;
|
|
194
|
+
FINTEGER nbiti = d_out, ni = static_cast<FINTEGER>(n), di = d_in;
|
|
189
195
|
sgemm_("Transposed",
|
|
190
196
|
"Not transposed",
|
|
191
197
|
&nbiti,
|
|
@@ -203,20 +209,21 @@ void LinearTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
203
209
|
|
|
204
210
|
void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
|
|
205
211
|
const {
|
|
212
|
+
std::vector<float> y_bias_corrected;
|
|
206
213
|
if (have_bias) { // allocate buffer to store bias-corrected data
|
|
207
|
-
|
|
214
|
+
y_bias_corrected.resize(n * d_out);
|
|
208
215
|
const float* yr = y;
|
|
209
|
-
float* yw =
|
|
216
|
+
float* yw = y_bias_corrected.data();
|
|
210
217
|
for (idx_t i = 0; i < n; i++) {
|
|
211
218
|
for (int j = 0; j < d_out; j++) {
|
|
212
219
|
*yw++ = *yr++ - b[j];
|
|
213
220
|
}
|
|
214
221
|
}
|
|
215
|
-
y =
|
|
222
|
+
y = y_bias_corrected.data();
|
|
216
223
|
}
|
|
217
224
|
|
|
218
225
|
{
|
|
219
|
-
FINTEGER dii = d_in, doi = d_out, ni = n;
|
|
226
|
+
FINTEGER dii = d_in, doi = d_out, ni = static_cast<FINTEGER>(n);
|
|
220
227
|
float one = 1.0, zero = 0.0;
|
|
221
228
|
sgemm_("Not",
|
|
222
229
|
"Not",
|
|
@@ -232,9 +239,6 @@ void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
|
|
|
232
239
|
x,
|
|
233
240
|
&dii);
|
|
234
241
|
}
|
|
235
|
-
|
|
236
|
-
if (have_bias)
|
|
237
|
-
delete[] y;
|
|
238
242
|
}
|
|
239
243
|
|
|
240
244
|
void LinearTransform::set_is_orthonormal() {
|
|
@@ -249,7 +253,7 @@ void LinearTransform::set_is_orthonormal() {
|
|
|
249
253
|
}
|
|
250
254
|
|
|
251
255
|
double eps = 4e-5;
|
|
252
|
-
FAISS_ASSERT(A.size() >= d_out * d_in);
|
|
256
|
+
FAISS_ASSERT(A.size() >= static_cast<size_t>(d_out) * d_in);
|
|
253
257
|
{
|
|
254
258
|
std::vector<float> ATA(d_out * d_out);
|
|
255
259
|
FINTEGER dii = d_in, doi = d_out;
|
|
@@ -273,9 +277,10 @@ void LinearTransform::set_is_orthonormal() {
|
|
|
273
277
|
for (long i = 0; i < d_out; i++) {
|
|
274
278
|
for (long j = 0; j < d_out; j++) {
|
|
275
279
|
float v = ATA[i + j * d_out];
|
|
276
|
-
if (i == j)
|
|
280
|
+
if (i == j) {
|
|
277
281
|
v -= 1;
|
|
278
|
-
|
|
282
|
+
}
|
|
283
|
+
if (std::fabs(v) > eps) {
|
|
279
284
|
is_orthonormal = false;
|
|
280
285
|
}
|
|
281
286
|
}
|
|
@@ -298,10 +303,13 @@ void LinearTransform::print_if_verbose(
|
|
|
298
303
|
const std::vector<double>& mat,
|
|
299
304
|
int n,
|
|
300
305
|
int d) const {
|
|
301
|
-
if (!verbose)
|
|
306
|
+
if (!verbose) {
|
|
302
307
|
return;
|
|
308
|
+
}
|
|
303
309
|
printf("matrix %s: %d*%d [\n", name, n, d);
|
|
304
|
-
|
|
310
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
311
|
+
mat.size() >= static_cast<size_t>(n) * d,
|
|
312
|
+
"matrix size is too small for the given dimensions");
|
|
305
313
|
for (int i = 0; i < n; i++) {
|
|
306
314
|
for (int j = 0; j < d; j++) {
|
|
307
315
|
printf("%10.5g ", mat[i * d + j]);
|
|
@@ -314,8 +322,10 @@ void LinearTransform::print_if_verbose(
|
|
|
314
322
|
void LinearTransform::check_identical(const VectorTransform& other_in) const {
|
|
315
323
|
VectorTransform::check_identical(other_in);
|
|
316
324
|
auto other = dynamic_cast<const LinearTransform*>(&other_in);
|
|
317
|
-
|
|
318
|
-
|
|
325
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to LinearTransform");
|
|
326
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
327
|
+
other->A == A && other->b == b,
|
|
328
|
+
"LinearTransform matrix A and bias vector b must match");
|
|
319
329
|
}
|
|
320
330
|
|
|
321
331
|
/*********************************************
|
|
@@ -352,18 +362,149 @@ void RandomRotationMatrix::train(idx_t /*n*/, const float* /*x*/) {
|
|
|
352
362
|
init(12345);
|
|
353
363
|
}
|
|
354
364
|
|
|
365
|
+
/*********************************************
|
|
366
|
+
* HadamardRotation
|
|
367
|
+
*********************************************/
|
|
368
|
+
|
|
369
|
+
// In-place Fast Walsh-Hadamard Transform. n must be a power of 2.
|
|
370
|
+
// Applies the unnormalized Hadamard butterfly: O(n log n) add/sub, no
|
|
371
|
+
// multiplies.
|
|
372
|
+
static void fwht_inplace(float* buf, size_t n) {
|
|
373
|
+
for (size_t step = 1; step < n; step *= 2) {
|
|
374
|
+
for (size_t i = 0; i < n; i += step * 2) {
|
|
375
|
+
for (size_t j = i; j < i + step; j++) {
|
|
376
|
+
float a = buf[j];
|
|
377
|
+
float b = buf[j + step];
|
|
378
|
+
buf[j] = a + b;
|
|
379
|
+
buf[j + step] = a - b;
|
|
380
|
+
}
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
}
|
|
384
|
+
|
|
385
|
+
// Smallest power of 2 >= n.
|
|
386
|
+
static int next_power_of_2(int n) {
|
|
387
|
+
int p = 1;
|
|
388
|
+
while (p < n) {
|
|
389
|
+
p *= 2;
|
|
390
|
+
}
|
|
391
|
+
return p;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
// Generate three sign-flip vectors from the given seed.
|
|
395
|
+
static void generate_signs(
|
|
396
|
+
uint32_t seed,
|
|
397
|
+
size_t p,
|
|
398
|
+
std::vector<float>& s1,
|
|
399
|
+
std::vector<float>& s2,
|
|
400
|
+
std::vector<float>& s3) {
|
|
401
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
402
|
+
p > 0, "number of Hadamard factors p must be positive");
|
|
403
|
+
SplitMix64RandomGenerator rng(seed);
|
|
404
|
+
s1.resize(p);
|
|
405
|
+
s2.resize(p);
|
|
406
|
+
s3.resize(p);
|
|
407
|
+
for (size_t j = 0; j < p; j++) {
|
|
408
|
+
s1[j] = (rng.rand_int(2) == 0) ? -1.0f : 1.0f;
|
|
409
|
+
}
|
|
410
|
+
for (size_t j = 0; j < p; j++) {
|
|
411
|
+
s2[j] = (rng.rand_int(2) == 0) ? -1.0f : 1.0f;
|
|
412
|
+
}
|
|
413
|
+
for (size_t j = 0; j < p; j++) {
|
|
414
|
+
s3[j] = (rng.rand_int(2) == 0) ? -1.0f : 1.0f;
|
|
415
|
+
}
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
HadamardRotation::HadamardRotation(int d, uint32_t seed_in)
|
|
419
|
+
: VectorTransform(d, next_power_of_2(d)), seed(seed_in) {
|
|
420
|
+
init(seed_in);
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
void HadamardRotation::init(uint32_t seed_in) {
|
|
424
|
+
seed = seed_in;
|
|
425
|
+
is_trained = true;
|
|
426
|
+
generate_signs(seed, d_out, signs1, signs2, signs3);
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
void HadamardRotation::train(idx_t, const float*) {
|
|
430
|
+
init(seed != 0 ? seed : 12345);
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
// Applies the randomized Hadamard rotation to n vectors of dimension d_in
// read from x, and writes n vectors of dimension d_out to xt.
//
// Each input row is zero-padded to d_out entries. Three rounds of "multiply
// by a +/-1 sign vector, then unnormalized fast Walsh-Hadamard transform"
// follow, using signs1, signs2 and signs3 in that order. A single scale
// factor applied at the end undoes the p^(3/2) norm growth of the three
// unnormalized transforms.
//
// Throws if the transform is not trained, if d_in > d_out, if d_out is not a
// power of two, or if any sign vector does not have d_out entries.
void HadamardRotation::apply_noalloc(idx_t n, const float* x, float* xt) const {
    FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");

    size_t d = d_in;
    size_t p = d_out;
    // Round 1 writes d input components into a row of p floats, so d must
    // not exceed p. fwht_inplace presumably requires a power-of-two length
    // (NOTE(review): confirm against its definition).
    // The constructor guarantees both (d_out = next_power_of_2(d_in)). They
    // are re-checked because an object whose fields did not come from that
    // constructor (e.g. a corrupted serialized one -- assumption) would
    // otherwise turn into out-of-bounds reads and writes.
    FAISS_THROW_IF_NOT_MSG(
            d <= p, "input dimension must not exceed output dimension");
    FAISS_THROW_IF_NOT_MSG(
            p > 0 && (p & (p - 1)) == 0,
            "output dimension must be a power of 2");
    FAISS_THROW_IF_NOT_MSG(
            signs1.size() == p,
            "sign-flip vector 1 size must match output dimension");
    FAISS_THROW_IF_NOT_MSG(
            signs2.size() == p,
            "sign-flip vector 2 size must match output dimension");
    FAISS_THROW_IF_NOT_MSG(
            signs3.size() == p,
            "sign-flip vector 3 size must match output dimension");

    // Each unnormalized FWHT scales norms by sqrt(p).
    // Three rounds scale by p^(3/2). Normalize once at the end.
    float total_scale = 1.0f / (p * std::sqrt(static_cast<float>(p)));

#pragma omp parallel for schedule(dynamic)
    for (idx_t i = 0; i < n; i++) {
        const float* xi = x + i * d;
        float* xo = xt + i * p;

        // Round 1: copy + zero-pad + sign-flip + FWHT
        for (size_t j = 0; j < d; j++) {
            xo[j] = xi[j] * signs1[j];
        }
        for (size_t j = d; j < p; j++) {
            xo[j] = 0.0f;
        }
        fwht_inplace(xo, p);

        // Round 2: sign-flip + FWHT
        for (size_t j = 0; j < p; j++) {
            xo[j] *= signs2[j];
        }
        fwht_inplace(xo, p);

        // Round 3: sign-flip + FWHT + normalize
        for (size_t j = 0; j < p; j++) {
            xo[j] *= signs3[j];
        }
        fwht_inplace(xo, p);

        for (size_t j = 0; j < p; j++) {
            xo[j] *= total_scale;
        }
    }
}
|
|
483
|
+
|
|
484
|
+
void HadamardRotation::check_identical(const VectorTransform& other) const {
|
|
485
|
+
auto* hr = dynamic_cast<const HadamardRotation*>(&other);
|
|
486
|
+
FAISS_THROW_IF_NOT_MSG(hr, "failed to cast to HadamardRotation");
|
|
487
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
488
|
+
d_in == hr->d_in, "HadamardRotation input dimensions must match");
|
|
489
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
490
|
+
d_out == hr->d_out,
|
|
491
|
+
"HadamardRotation output dimensions must match");
|
|
492
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
493
|
+
seed == hr->seed, "HadamardRotation seeds must match");
|
|
494
|
+
}
|
|
495
|
+
|
|
355
496
|
/*********************************************
|
|
356
497
|
* PCAMatrix
|
|
357
498
|
*********************************************/
|
|
358
499
|
|
|
359
500
|
PCAMatrix::PCAMatrix(
|
|
360
|
-
int
|
|
361
|
-
int
|
|
362
|
-
float
|
|
363
|
-
bool
|
|
364
|
-
: LinearTransform(
|
|
365
|
-
eigen_power(
|
|
366
|
-
random_rotation(
|
|
501
|
+
int din,
|
|
502
|
+
int dout,
|
|
503
|
+
float eigen_power_in,
|
|
504
|
+
bool random_rotation_in)
|
|
505
|
+
: LinearTransform(din, dout, true),
|
|
506
|
+
eigen_power(eigen_power_in),
|
|
507
|
+
random_rotation(random_rotation_in) {
|
|
367
508
|
is_trained = false;
|
|
368
509
|
max_points_per_d = 1000;
|
|
369
510
|
balanced_bins = 0;
|
|
@@ -377,7 +518,7 @@ namespace {
|
|
|
377
518
|
|
|
378
519
|
void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
379
520
|
{ // compute eigenvalues and vectors
|
|
380
|
-
FINTEGER info = 0, lwork = -1, di = d_in;
|
|
521
|
+
FINTEGER info = 0, lwork = -1, di = static_cast<FINTEGER>(d_in);
|
|
381
522
|
double workq;
|
|
382
523
|
|
|
383
524
|
dsyev_("Vectors as well",
|
|
@@ -389,8 +530,8 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
389
530
|
&workq,
|
|
390
531
|
&lwork,
|
|
391
532
|
&info);
|
|
392
|
-
lwork = FINTEGER(workq);
|
|
393
|
-
double
|
|
533
|
+
lwork = static_cast<FINTEGER>(workq);
|
|
534
|
+
std::vector<double> work(lwork);
|
|
394
535
|
|
|
395
536
|
dsyev_("Vectors as well",
|
|
396
537
|
"Upper",
|
|
@@ -398,12 +539,10 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
398
539
|
cov,
|
|
399
540
|
&di,
|
|
400
541
|
eigenvalues,
|
|
401
|
-
work,
|
|
542
|
+
work.data(),
|
|
402
543
|
&lwork,
|
|
403
544
|
&info);
|
|
404
545
|
|
|
405
|
-
delete[] work;
|
|
406
|
-
|
|
407
546
|
if (info != 0) {
|
|
408
547
|
fprintf(stderr,
|
|
409
548
|
"WARN ssyev info returns %d, "
|
|
@@ -414,15 +553,17 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
414
553
|
|
|
415
554
|
if (verbose && d_in <= 10) {
|
|
416
555
|
printf("info=%ld new eigvals=[", long(info));
|
|
417
|
-
for (
|
|
556
|
+
for (size_t j = 0; j < d_in; j++) {
|
|
418
557
|
printf("%g ", eigenvalues[j]);
|
|
558
|
+
}
|
|
419
559
|
printf("]\n");
|
|
420
560
|
|
|
421
561
|
double* ci = cov;
|
|
422
562
|
printf("eigenvecs=\n");
|
|
423
|
-
for (
|
|
424
|
-
for (
|
|
563
|
+
for (size_t i = 0; i < d_in; i++) {
|
|
564
|
+
for (size_t j = 0; j < d_in; j++) {
|
|
425
565
|
printf("%10.4g ", *ci++);
|
|
566
|
+
}
|
|
426
567
|
printf("\n");
|
|
427
568
|
}
|
|
428
569
|
}
|
|
@@ -430,12 +571,13 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
430
571
|
|
|
431
572
|
// revert order of eigenvectors & values
|
|
432
573
|
|
|
433
|
-
for (
|
|
574
|
+
for (size_t i = 0; i < d_in / 2; i++) {
|
|
434
575
|
std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
|
|
435
576
|
double* v1 = cov + i * d_in;
|
|
436
577
|
double* v2 = cov + (d_in - 1 - i) * d_in;
|
|
437
|
-
for (
|
|
578
|
+
for (size_t j = 0; j < d_in; j++) {
|
|
438
579
|
std::swap(v1[j], v2[j]);
|
|
580
|
+
}
|
|
439
581
|
}
|
|
440
582
|
}
|
|
441
583
|
|
|
@@ -451,17 +593,20 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
451
593
|
mean.resize(d_in, 0.0);
|
|
452
594
|
if (have_bias) { // we may want to skip the bias
|
|
453
595
|
const float* xi = x;
|
|
454
|
-
for (
|
|
455
|
-
for (int j = 0; j < d_in; j++)
|
|
596
|
+
for (idx_t i = 0; i < n; i++) {
|
|
597
|
+
for (int j = 0; j < d_in; j++) {
|
|
456
598
|
mean[j] += *xi++;
|
|
599
|
+
}
|
|
457
600
|
}
|
|
458
|
-
for (int j = 0; j < d_in; j++)
|
|
601
|
+
for (int j = 0; j < d_in; j++) {
|
|
459
602
|
mean[j] /= n;
|
|
603
|
+
}
|
|
460
604
|
}
|
|
461
605
|
if (verbose) {
|
|
462
606
|
printf("mean=[");
|
|
463
|
-
for (int j = 0; j < d_in; j++)
|
|
607
|
+
for (int j = 0; j < d_in; j++) {
|
|
464
608
|
printf("%g ", mean[j]);
|
|
609
|
+
}
|
|
465
610
|
printf("]\n");
|
|
466
611
|
}
|
|
467
612
|
|
|
@@ -472,12 +617,13 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
472
617
|
{ // initialize with mean * mean^T term
|
|
473
618
|
float* ci = cov;
|
|
474
619
|
for (int i = 0; i < d_in; i++) {
|
|
475
|
-
for (int j = 0; j < d_in; j++)
|
|
620
|
+
for (int j = 0; j < d_in; j++) {
|
|
476
621
|
*ci++ = -n * mean[i] * mean[j];
|
|
622
|
+
}
|
|
477
623
|
}
|
|
478
624
|
}
|
|
479
625
|
{
|
|
480
|
-
FINTEGER di = d_in, ni = n;
|
|
626
|
+
FINTEGER di = d_in, ni = static_cast<FINTEGER>(n);
|
|
481
627
|
float one = 1.0;
|
|
482
628
|
ssyrk_("Up",
|
|
483
629
|
"Non transposed",
|
|
@@ -494,38 +640,44 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
494
640
|
float* ci = cov;
|
|
495
641
|
printf("cov=\n");
|
|
496
642
|
for (int i = 0; i < d_in; i++) {
|
|
497
|
-
for (int j = 0; j < d_in; j++)
|
|
643
|
+
for (int j = 0; j < d_in; j++) {
|
|
498
644
|
printf("%10g ", *ci++);
|
|
645
|
+
}
|
|
499
646
|
printf("\n");
|
|
500
647
|
}
|
|
501
648
|
}
|
|
502
649
|
|
|
503
650
|
std::vector<double> covd(d_in * d_in);
|
|
504
|
-
for (size_t i = 0; i < d_in * d_in; i++)
|
|
651
|
+
for (size_t i = 0; i < d_in * d_in; i++) {
|
|
505
652
|
covd[i] = cov[i];
|
|
653
|
+
}
|
|
506
654
|
|
|
507
655
|
std::vector<double> eigenvaluesd(d_in);
|
|
508
656
|
|
|
509
657
|
eig(d_in, covd.data(), eigenvaluesd.data(), verbose);
|
|
510
658
|
|
|
511
|
-
for (size_t i = 0; i < d_in * d_in; i++)
|
|
659
|
+
for (size_t i = 0; i < d_in * d_in; i++) {
|
|
512
660
|
PCAMat[i] = covd[i];
|
|
661
|
+
}
|
|
513
662
|
eigenvalues.resize(d_in);
|
|
514
663
|
|
|
515
|
-
for (
|
|
664
|
+
for (int i = 0; i < d_in; i++) {
|
|
516
665
|
eigenvalues[i] = eigenvaluesd[i];
|
|
666
|
+
}
|
|
517
667
|
|
|
518
668
|
} else {
|
|
519
669
|
std::vector<float> xc(n * d_in);
|
|
520
670
|
|
|
521
|
-
for (
|
|
522
|
-
for (
|
|
671
|
+
for (idx_t i = 0; i < n; i++) {
|
|
672
|
+
for (int j = 0; j < d_in; j++) {
|
|
523
673
|
xc[i * d_in + j] = x[i * d_in + j] - mean[j];
|
|
674
|
+
}
|
|
675
|
+
}
|
|
524
676
|
|
|
525
677
|
// compute Gram matrix
|
|
526
678
|
std::vector<float> gram(n * n);
|
|
527
679
|
{
|
|
528
|
-
FINTEGER di = d_in, ni = n;
|
|
680
|
+
FINTEGER di = d_in, ni = static_cast<FINTEGER>(n);
|
|
529
681
|
float one = 1.0, zero = 0.0;
|
|
530
682
|
ssyrk_("Up",
|
|
531
683
|
"Transposed",
|
|
@@ -542,16 +694,18 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
542
694
|
if (verbose && d_in <= 10) {
|
|
543
695
|
float* ci = gram.data();
|
|
544
696
|
printf("gram=\n");
|
|
545
|
-
for (
|
|
546
|
-
for (
|
|
697
|
+
for (idx_t i = 0; i < n; i++) {
|
|
698
|
+
for (idx_t j = 0; j < n; j++) {
|
|
547
699
|
printf("%10g ", *ci++);
|
|
700
|
+
}
|
|
548
701
|
printf("\n");
|
|
549
702
|
}
|
|
550
703
|
}
|
|
551
704
|
|
|
552
705
|
std::vector<double> gramd(n * n);
|
|
553
|
-
for (size_t i = 0; i < n * n; i++)
|
|
706
|
+
for (size_t i = 0; i < n * n; i++) {
|
|
554
707
|
gramd[i] = gram[i];
|
|
708
|
+
}
|
|
555
709
|
|
|
556
710
|
std::vector<double> eigenvaluesd(n);
|
|
557
711
|
|
|
@@ -561,17 +715,19 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
561
715
|
|
|
562
716
|
PCAMat.resize(d_in * n);
|
|
563
717
|
|
|
564
|
-
for (size_t i = 0; i < n * n; i++)
|
|
718
|
+
for (size_t i = 0; i < n * n; i++) {
|
|
565
719
|
gram[i] = gramd[i];
|
|
720
|
+
}
|
|
566
721
|
|
|
567
722
|
eigenvalues.resize(d_in);
|
|
568
723
|
// fill in only the n first ones
|
|
569
|
-
for (
|
|
724
|
+
for (idx_t i = 0; i < n; i++) {
|
|
570
725
|
eigenvalues[i] = eigenvaluesd[i];
|
|
726
|
+
}
|
|
571
727
|
|
|
572
728
|
{ // compute PCAMat = x' * v
|
|
573
|
-
FINTEGER di = d_in, ni = n;
|
|
574
|
-
float one = 1.0;
|
|
729
|
+
FINTEGER di = d_in, ni = static_cast<FINTEGER>(n);
|
|
730
|
+
float one = 1.0, zero = 0.0;
|
|
575
731
|
|
|
576
732
|
sgemm_("Non",
|
|
577
733
|
"Non Trans",
|
|
@@ -583,7 +739,7 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
583
739
|
&di,
|
|
584
740
|
gram.data(),
|
|
585
741
|
&ni,
|
|
586
|
-
&
|
|
742
|
+
&zero,
|
|
587
743
|
PCAMat.data(),
|
|
588
744
|
&di);
|
|
589
745
|
}
|
|
@@ -591,9 +747,10 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
591
747
|
if (verbose && d_in <= 10) {
|
|
592
748
|
float* ci = PCAMat.data();
|
|
593
749
|
printf("PCAMat=\n");
|
|
594
|
-
for (
|
|
595
|
-
for (int j = 0; j < d_in; j++)
|
|
750
|
+
for (idx_t i = 0; i < n; i++) {
|
|
751
|
+
for (int j = 0; j < d_in; j++) {
|
|
596
752
|
printf("%10g ", *ci++);
|
|
753
|
+
}
|
|
597
754
|
printf("\n");
|
|
598
755
|
}
|
|
599
756
|
}
|
|
@@ -605,7 +762,9 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
605
762
|
}
|
|
606
763
|
|
|
607
764
|
void PCAMatrix::copy_from(const PCAMatrix& other) {
|
|
608
|
-
|
|
765
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
766
|
+
other.is_trained,
|
|
767
|
+
"source PCAMatrix must be trained before copying");
|
|
609
768
|
mean = other.mean;
|
|
610
769
|
eigenvalues = other.eigenvalues;
|
|
611
770
|
PCAMat = other.PCAMat;
|
|
@@ -615,7 +774,7 @@ void PCAMatrix::copy_from(const PCAMatrix& other) {
|
|
|
615
774
|
|
|
616
775
|
void PCAMatrix::prepare_Ab() {
|
|
617
776
|
FAISS_THROW_IF_NOT_FMT(
|
|
618
|
-
d_out * d_in <= PCAMat.size(),
|
|
777
|
+
static_cast<size_t>(d_out) * d_in <= PCAMat.size(),
|
|
619
778
|
"PCA matrix cannot output %d dimensions from %d ",
|
|
620
779
|
d_out,
|
|
621
780
|
d_in);
|
|
@@ -628,14 +787,17 @@ void PCAMatrix::prepare_Ab() {
|
|
|
628
787
|
if (eigen_power != 0) {
|
|
629
788
|
float* ai = A.data();
|
|
630
789
|
for (int i = 0; i < d_out; i++) {
|
|
631
|
-
float factor = pow(eigenvalues[i] + epsilon, eigen_power);
|
|
632
|
-
for (int j = 0; j < d_in; j++)
|
|
790
|
+
float factor = std::pow(eigenvalues[i] + epsilon, eigen_power);
|
|
791
|
+
for (int j = 0; j < d_in; j++) {
|
|
633
792
|
*ai++ *= factor;
|
|
793
|
+
}
|
|
634
794
|
}
|
|
635
795
|
}
|
|
636
796
|
|
|
637
797
|
if (balanced_bins != 0) {
|
|
638
|
-
|
|
798
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
799
|
+
d_out % balanced_bins == 0,
|
|
800
|
+
"output dimension must be divisible by balanced_bins");
|
|
639
801
|
int dsub = d_out / balanced_bins;
|
|
640
802
|
std::vector<float> Ain;
|
|
641
803
|
std::swap(A, Ain);
|
|
@@ -663,8 +825,9 @@ void PCAMatrix::prepare_Ab() {
|
|
|
663
825
|
|
|
664
826
|
if (verbose) {
|
|
665
827
|
printf(" bin accu=[");
|
|
666
|
-
for (int i = 0; i < balanced_bins; i++)
|
|
828
|
+
for (int i = 0; i < balanced_bins; i++) {
|
|
667
829
|
printf("%g ", accu[i]);
|
|
830
|
+
}
|
|
668
831
|
printf("]\n");
|
|
669
832
|
}
|
|
670
833
|
}
|
|
@@ -682,8 +845,9 @@ void PCAMatrix::prepare_Ab() {
|
|
|
682
845
|
if (eigen_power != 0) {
|
|
683
846
|
for (int i = 0; i < d_out; i++) {
|
|
684
847
|
float factor = pow(eigenvalues[i], eigen_power);
|
|
685
|
-
for (int j = 0; j < d_out; j++)
|
|
848
|
+
for (int j = 0; j < d_out; j++) {
|
|
686
849
|
rr.A[j * d_out + i] *= factor;
|
|
850
|
+
}
|
|
687
851
|
}
|
|
688
852
|
}
|
|
689
853
|
|
|
@@ -713,8 +877,9 @@ void PCAMatrix::prepare_Ab() {
|
|
|
713
877
|
|
|
714
878
|
for (int i = 0; i < d_out; i++) {
|
|
715
879
|
float accu = 0;
|
|
716
|
-
for (int j = 0; j < d_in; j++)
|
|
880
|
+
for (int j = 0; j < d_in; j++) {
|
|
717
881
|
accu -= mean[j] * A[j + i * d_in];
|
|
882
|
+
}
|
|
718
883
|
b[i] = accu;
|
|
719
884
|
}
|
|
720
885
|
|
|
@@ -738,7 +903,7 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
738
903
|
init_rotation.data(),
|
|
739
904
|
d * d * sizeof(rotation[0]));
|
|
740
905
|
} else {
|
|
741
|
-
RandomRotationMatrix rrot(d, d);
|
|
906
|
+
RandomRotationMatrix rrot(static_cast<int>(d), static_cast<int>(d));
|
|
742
907
|
rrot.init(seed);
|
|
743
908
|
for (size_t i = 0; i < d * d; i++) {
|
|
744
909
|
rotation[i] = rrot.A[i];
|
|
@@ -755,9 +920,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
755
920
|
std::vector<double> u(d * d), vt(d * d), singvals(d);
|
|
756
921
|
|
|
757
922
|
for (int i = 0; i < max_iter; i++) {
|
|
758
|
-
print_if_verbose(
|
|
923
|
+
print_if_verbose(
|
|
924
|
+
"rotation", rotation, static_cast<int>(d), static_cast<int>(d));
|
|
759
925
|
{ // rotated_data = np.dot(training_data, rotation)
|
|
760
|
-
FINTEGER di = d,
|
|
926
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
927
|
+
ni = static_cast<FINTEGER>(n);
|
|
761
928
|
double one = 1, zero = 0;
|
|
762
929
|
dgemm_("N",
|
|
763
930
|
"N",
|
|
@@ -773,14 +940,19 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
773
940
|
rotated_x.data(),
|
|
774
941
|
&di);
|
|
775
942
|
}
|
|
776
|
-
print_if_verbose(
|
|
943
|
+
print_if_verbose(
|
|
944
|
+
"rotated_x",
|
|
945
|
+
rotated_x,
|
|
946
|
+
static_cast<int>(n),
|
|
947
|
+
static_cast<int>(d));
|
|
777
948
|
// binarize
|
|
778
949
|
for (size_t j = 0; j < n * d; j++) {
|
|
779
950
|
rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
|
|
780
951
|
}
|
|
781
952
|
// covariance matrix
|
|
782
953
|
{ // rotated_data = np.dot(training_data, rotation)
|
|
783
|
-
FINTEGER di = d,
|
|
954
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
955
|
+
ni = static_cast<FINTEGER>(n);
|
|
784
956
|
double one = 1, zero = 0;
|
|
785
957
|
dgemm_("N",
|
|
786
958
|
"T",
|
|
@@ -796,10 +968,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
796
968
|
cov_mat.data(),
|
|
797
969
|
&di);
|
|
798
970
|
}
|
|
799
|
-
print_if_verbose(
|
|
971
|
+
print_if_verbose(
|
|
972
|
+
"cov_mat", cov_mat, static_cast<int>(d), static_cast<int>(d));
|
|
800
973
|
// SVD
|
|
801
974
|
{
|
|
802
|
-
FINTEGER di = d;
|
|
975
|
+
FINTEGER di = static_cast<FINTEGER>(d);
|
|
803
976
|
FINTEGER lwork = -1, info;
|
|
804
977
|
double lwork1;
|
|
805
978
|
|
|
@@ -819,8 +992,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
819
992
|
&lwork,
|
|
820
993
|
&info);
|
|
821
994
|
|
|
822
|
-
|
|
823
|
-
|
|
995
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
996
|
+
info == 0,
|
|
997
|
+
"LAPACK dgesvd workspace query returned info=%d",
|
|
998
|
+
int(info));
|
|
999
|
+
lwork = static_cast<FINTEGER>(lwork1);
|
|
824
1000
|
std::vector<double> work(lwork);
|
|
825
1001
|
dgesvd_("A",
|
|
826
1002
|
"A",
|
|
@@ -838,11 +1014,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
838
1014
|
&info);
|
|
839
1015
|
FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
|
|
840
1016
|
}
|
|
841
|
-
print_if_verbose("u", u, d, d);
|
|
842
|
-
print_if_verbose("vt", vt, d, d);
|
|
1017
|
+
print_if_verbose("u", u, static_cast<int>(d), static_cast<int>(d));
|
|
1018
|
+
print_if_verbose("vt", vt, static_cast<int>(d), static_cast<int>(d));
|
|
843
1019
|
// update rotation
|
|
844
1020
|
{
|
|
845
|
-
FINTEGER di = d;
|
|
1021
|
+
FINTEGER di = static_cast<FINTEGER>(d);
|
|
846
1022
|
double one = 1, zero = 0;
|
|
847
1023
|
dgemm_("N",
|
|
848
1024
|
"T",
|
|
@@ -858,7 +1034,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
858
1034
|
rotation.data(),
|
|
859
1035
|
&di);
|
|
860
1036
|
}
|
|
861
|
-
print_if_verbose(
|
|
1037
|
+
print_if_verbose(
|
|
1038
|
+
"final rot",
|
|
1039
|
+
rotation,
|
|
1040
|
+
static_cast<int>(d),
|
|
1041
|
+
static_cast<int>(d));
|
|
862
1042
|
}
|
|
863
1043
|
A.resize(d * d);
|
|
864
1044
|
for (size_t i = 0; i < d; i++) {
|
|
@@ -869,20 +1049,23 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
869
1049
|
is_trained = true;
|
|
870
1050
|
}
|
|
871
1051
|
|
|
872
|
-
ITQTransform::ITQTransform(int
|
|
873
|
-
: VectorTransform(
|
|
874
|
-
do_pca(
|
|
875
|
-
itq(
|
|
876
|
-
pca_then_itq(
|
|
877
|
-
if (!
|
|
878
|
-
|
|
1052
|
+
// Builds an ITQ transform from din to dout dimensions.
// `itq` is constructed with dout and `pca_then_itq` with (din, dout, false).
// When do_pca_in is false there is no PCA stage, so din and dout must be
// equal; otherwise this throws. The object starts untrained, with
// max_train_per_dim set to 10.
ITQTransform::ITQTransform(int din, int dout, bool do_pca_in)
        : VectorTransform(din, dout),
          do_pca(do_pca_in),
          itq(dout),
          pca_then_itq(din, dout, false) {
    max_train_per_dim = 10;
    is_trained = false;

    // do_pca already holds do_pca_in at this point.
    if (!do_pca) {
        FAISS_THROW_IF_NOT_MSG(
                din == dout,
                "input and output dimensions must match when PCA is disabled");
    }
}
|
|
883
1065
|
|
|
884
1066
|
void ITQTransform::train(idx_t n, const float* x_in) {
|
|
885
|
-
|
|
1067
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1068
|
+
!is_trained, "ITQTransform has already been trained");
|
|
886
1069
|
|
|
887
1070
|
size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
|
|
888
1071
|
const float* x =
|
|
@@ -974,17 +1157,18 @@ void ITQTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
974
1157
|
void ITQTransform::check_identical(const VectorTransform& other_in) const {
|
|
975
1158
|
VectorTransform::check_identical(other_in);
|
|
976
1159
|
auto other = dynamic_cast<const ITQTransform*>(&other_in);
|
|
977
|
-
|
|
1160
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to ITQTransform");
|
|
978
1161
|
pca_then_itq.check_identical(other->pca_then_itq);
|
|
979
|
-
|
|
1162
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1163
|
+
other->mean == mean, "ITQTransform mean vectors must match");
|
|
980
1164
|
}
|
|
981
1165
|
|
|
982
1166
|
/*********************************************
|
|
983
1167
|
* OPQMatrix
|
|
984
1168
|
*********************************************/
|
|
985
1169
|
|
|
986
|
-
OPQMatrix::OPQMatrix(int d, int
|
|
987
|
-
: LinearTransform(d, d2 == -1 ? d : d2, false), M(
|
|
1170
|
+
OPQMatrix::OPQMatrix(int d, int M_in, int d2)
|
|
1171
|
+
: LinearTransform(d, d2 == -1 ? d : d2, false), M(M_in) {
|
|
988
1172
|
is_trained = false;
|
|
989
1173
|
// OPQ is quite expensive to train, so set this right.
|
|
990
1174
|
max_train_points = 256 * 256;
|
|
@@ -1030,17 +1214,20 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1030
1214
|
{
|
|
1031
1215
|
std::vector<float> sum(d);
|
|
1032
1216
|
const float* xi = x;
|
|
1033
|
-
for (
|
|
1034
|
-
for (int j = 0; j < d_in; j++)
|
|
1217
|
+
for (idx_t i = 0; i < n; i++) {
|
|
1218
|
+
for (int j = 0; j < d_in; j++) {
|
|
1035
1219
|
sum[j] += *xi++;
|
|
1220
|
+
}
|
|
1036
1221
|
}
|
|
1037
|
-
for (
|
|
1222
|
+
for (size_t i = 0; i < d; i++) {
|
|
1038
1223
|
sum[i] /= n;
|
|
1224
|
+
}
|
|
1039
1225
|
float* yi = xtrain.data();
|
|
1040
1226
|
xi = x;
|
|
1041
|
-
for (
|
|
1042
|
-
for (int j = 0; j < d_in; j++)
|
|
1227
|
+
for (idx_t i = 0; i < n; i++) {
|
|
1228
|
+
for (int j = 0; j < d_in; j++) {
|
|
1043
1229
|
*yi++ = *xi++ - sum[j];
|
|
1230
|
+
}
|
|
1044
1231
|
yi += d - d_in;
|
|
1045
1232
|
}
|
|
1046
1233
|
}
|
|
@@ -1049,16 +1236,18 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1049
1236
|
if (A.size() == 0) {
|
|
1050
1237
|
A.resize(d * d);
|
|
1051
1238
|
rotation = A.data();
|
|
1052
|
-
if (verbose)
|
|
1239
|
+
if (verbose) {
|
|
1053
1240
|
printf(" OPQMatrix::train: making random %zd*%zd rotation\n",
|
|
1054
1241
|
d,
|
|
1055
1242
|
d);
|
|
1243
|
+
}
|
|
1056
1244
|
float_randn(rotation, d * d, 1234);
|
|
1057
1245
|
matrix_qr(d, d, rotation);
|
|
1058
1246
|
// we use only the d * d2 upper part of the matrix
|
|
1059
1247
|
A.resize(d * d2);
|
|
1060
1248
|
} else {
|
|
1061
|
-
|
|
1249
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1250
|
+
A.size() == d * d2, "rotation matrix A has incorrect size");
|
|
1062
1251
|
rotation = A.data();
|
|
1063
1252
|
}
|
|
1064
1253
|
|
|
@@ -1072,7 +1261,9 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1072
1261
|
double t0 = getmillisecs();
|
|
1073
1262
|
for (int iter = 0; iter < niter; iter++) {
|
|
1074
1263
|
{ // torch.mm(xtrain, rotation:t())
|
|
1075
|
-
FINTEGER di = d,
|
|
1264
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
1265
|
+
d2i = static_cast<FINTEGER>(d2),
|
|
1266
|
+
ni = static_cast<FINTEGER>(n);
|
|
1076
1267
|
float zero = 0, one = 1;
|
|
1077
1268
|
sgemm_("Transposed",
|
|
1078
1269
|
"Not transposed",
|
|
@@ -1107,18 +1298,21 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1107
1298
|
|
|
1108
1299
|
float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;
|
|
1109
1300
|
|
|
1110
|
-
if (verbose)
|
|
1301
|
+
if (verbose) {
|
|
1111
1302
|
printf(" Iteration %d (%d PQ iterations):"
|
|
1112
1303
|
"%.3f s, obj=%g\n",
|
|
1113
1304
|
iter,
|
|
1114
1305
|
pq_regular.cp.niter,
|
|
1115
1306
|
(getmillisecs() - t0) / 1000.0,
|
|
1116
1307
|
pq_err);
|
|
1308
|
+
}
|
|
1117
1309
|
|
|
1118
1310
|
{
|
|
1119
1311
|
float *u = tmp.data(), *vt = &tmp[d * d];
|
|
1120
1312
|
float* sing_val = &tmp[2 * d * d];
|
|
1121
|
-
FINTEGER di = d,
|
|
1313
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
1314
|
+
d2i = static_cast<FINTEGER>(d2),
|
|
1315
|
+
ni = static_cast<FINTEGER>(n);
|
|
1122
1316
|
float one = 1, zero = 0;
|
|
1123
1317
|
|
|
1124
1318
|
if (verbose) {
|
|
@@ -1157,7 +1351,11 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1157
1351
|
&lwork,
|
|
1158
1352
|
&info);
|
|
1159
1353
|
|
|
1160
|
-
|
|
1354
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
1355
|
+
info == 0,
|
|
1356
|
+
"LAPACK sgesvd workspace query returned info=%d",
|
|
1357
|
+
int(info));
|
|
1358
|
+
lwork = static_cast<FINTEGER>(worksz);
|
|
1161
1359
|
std::vector<float> work(lwork);
|
|
1162
1360
|
// u and vt swapped
|
|
1163
1361
|
sgesvd_("All",
|
|
@@ -1193,9 +1391,10 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1193
1391
|
}
|
|
1194
1392
|
|
|
1195
1393
|
// revert A matrix
|
|
1196
|
-
if (d > d_in) {
|
|
1197
|
-
for (long i = 0; i < d_out; i++)
|
|
1394
|
+
if (d > static_cast<size_t>(d_in)) {
|
|
1395
|
+
for (long i = 0; i < d_out; i++) {
|
|
1198
1396
|
memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
|
|
1397
|
+
}
|
|
1199
1398
|
A.resize(d_in * d_out);
|
|
1200
1399
|
}
|
|
1201
1400
|
|
|
@@ -1207,8 +1406,8 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1207
1406
|
* NormalizationTransform
|
|
1208
1407
|
*********************************************/
|
|
1209
1408
|
|
|
1210
|
-
NormalizationTransform::NormalizationTransform(int d, float
|
|
1211
|
-
: VectorTransform(d, d), norm(
|
|
1409
|
+
// Builds a d -> d transform and stores norm_in in the `norm` member.
// All state is set in the initializer list.
NormalizationTransform::NormalizationTransform(int d, float norm_in)
        : VectorTransform(d, d),
          norm(norm_in) {}
|
|
1212
1411
|
|
|
1213
1412
|
// Default constructor: the dimensions and `norm` are all set to -1.
// NOTE(review): these look like placeholders meant to be overwritten later
// (e.g. when deserializing) -- confirm.
NormalizationTransform::NormalizationTransform()
        : VectorTransform(-1, -1),
          norm(-1) {}
|
|
@@ -1234,8 +1433,9 @@ void NormalizationTransform::check_identical(
|
|
|
1234
1433
|
const VectorTransform& other_in) const {
|
|
1235
1434
|
VectorTransform::check_identical(other_in);
|
|
1236
1435
|
auto other = dynamic_cast<const NormalizationTransform*>(&other_in);
|
|
1237
|
-
|
|
1238
|
-
|
|
1436
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to NormalizationTransform");
|
|
1437
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1438
|
+
other->norm == norm, "normalization type must match");
|
|
1239
1439
|
}
|
|
1240
1440
|
|
|
1241
1441
|
/*********************************************
|
|
@@ -1250,12 +1450,12 @@ void CenteringTransform::train(idx_t n, const float* x) {
|
|
|
1250
1450
|
FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
|
|
1251
1451
|
mean.resize(d_in, 0);
|
|
1252
1452
|
for (idx_t i = 0; i < n; i++) {
|
|
1253
|
-
for (
|
|
1453
|
+
for (int j = 0; j < d_in; j++) {
|
|
1254
1454
|
mean[j] += *x++;
|
|
1255
1455
|
}
|
|
1256
1456
|
}
|
|
1257
1457
|
|
|
1258
|
-
for (
|
|
1458
|
+
for (int j = 0; j < d_in; j++) {
|
|
1259
1459
|
mean[j] /= n;
|
|
1260
1460
|
}
|
|
1261
1461
|
is_trained = true;
|
|
@@ -1263,10 +1463,11 @@ void CenteringTransform::train(idx_t n, const float* x) {
|
|
|
1263
1463
|
|
|
1264
1464
|
void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
|
|
1265
1465
|
const {
|
|
1266
|
-
|
|
1466
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1467
|
+
is_trained, "CenteringTransform has not been trained");
|
|
1267
1468
|
|
|
1268
1469
|
for (idx_t i = 0; i < n; i++) {
|
|
1269
|
-
for (
|
|
1470
|
+
for (int j = 0; j < d_in; j++) {
|
|
1270
1471
|
*xt++ = *x++ - mean[j];
|
|
1271
1472
|
}
|
|
1272
1473
|
}
|
|
@@ -1274,10 +1475,11 @@ void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
|
|
|
1274
1475
|
|
|
1275
1476
|
void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
|
|
1276
1477
|
const {
|
|
1277
|
-
|
|
1478
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1479
|
+
is_trained, "CenteringTransform has not been trained");
|
|
1278
1480
|
|
|
1279
1481
|
for (idx_t i = 0; i < n; i++) {
|
|
1280
|
-
for (
|
|
1482
|
+
for (int j = 0; j < d_in; j++) {
|
|
1281
1483
|
*x++ = *xt++ + mean[j];
|
|
1282
1484
|
}
|
|
1283
1485
|
}
|
|
@@ -1287,8 +1489,9 @@ void CenteringTransform::check_identical(
|
|
|
1287
1489
|
const VectorTransform& other_in) const {
|
|
1288
1490
|
VectorTransform::check_identical(other_in);
|
|
1289
1491
|
auto other = dynamic_cast<const CenteringTransform*>(&other_in);
|
|
1290
|
-
|
|
1291
|
-
|
|
1492
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to CenteringTransform");
|
|
1493
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1494
|
+
other->mean == mean, "CenteringTransform mean vectors must match");
|
|
1292
1495
|
}
|
|
1293
1496
|
|
|
1294
1497
|
/*********************************************
|
|
@@ -1296,37 +1499,40 @@ void CenteringTransform::check_identical(
|
|
|
1296
1499
|
*********************************************/
|
|
1297
1500
|
|
|
1298
1501
|
// Builds a remapping from din to dout dimensions using an explicit table:
// output dimension i takes input dimension map_in[i]. Per the error message
// below, an entry of -1 marks an unused output dimension. Every other entry
// must be a valid input index in [0, din), otherwise this throws. map_in must
// point to at least dout ints.
RemapDimensionsTransform::RemapDimensionsTransform(
        int din,
        int dout,
        const int* map_in)
        : VectorTransform(din, dout) {
    map.resize(dout);
    int i = 0;
    while (i < dout) {
        // Copy first, then validate the stored entry.
        map[i] = map_in[i];
        FAISS_THROW_IF_NOT_MSG(
                map[i] == -1 || (map[i] >= 0 && map[i] < din),
                "map entries must be -1 (unused) or valid input dimension indices");
        ++i;
    }
}
|
|
1309
1514
|
|
|
1310
1515
|
// Builds a remapping from din to dout dimensions with a generated table.
// Entries not assigned below stay -1.
//
// uniform == true:
//   - din < dout: the din inputs are spread evenly over the outputs; input i
//     lands at output i * dout / din.
//   - otherwise: output i reads input i * din / dout (evenly subsampled).
// uniform == false: the first min(din, dout) dimensions map to themselves.
//
// The products i * dout and i * din are formed in 64-bit arithmetic. In
// plain int they overflow (undefined behavior) once the dimensions are large,
// e.g. din * dout > INT_MAX. Each quotient is strictly below the
// corresponding bound (dout or din), so it fits back into int.
RemapDimensionsTransform::RemapDimensionsTransform(
        int din,
        int dout,
        bool uniform)
        : VectorTransform(din, dout) {
    map.resize(dout, -1);

    if (uniform) {
        if (din < dout) {
            for (int i = 0; i < din; i++) {
                map[static_cast<int64_t>(i) * dout / din] = i;
            }
        } else {
            for (int i = 0; i < dout; i++) {
                map[i] = static_cast<int>(
                        static_cast<int64_t>(i) * din / dout);
            }
        }
    } else {
        for (int i = 0; i < din && i < dout; i++) {
            map[i] = i;
        }
    }
}
|
|
1332
1538
|
|
|
@@ -1348,8 +1554,9 @@ void RemapDimensionsTransform::reverse_transform(
|
|
|
1348
1554
|
memset(x, 0, sizeof(*x) * n * d_in);
|
|
1349
1555
|
for (idx_t i = 0; i < n; i++) {
|
|
1350
1556
|
for (int j = 0; j < d_out; j++) {
|
|
1351
|
-
if (map[j] >= 0)
|
|
1557
|
+
if (map[j] >= 0) {
|
|
1352
1558
|
x[map[j]] = xt[j];
|
|
1559
|
+
}
|
|
1353
1560
|
}
|
|
1354
1561
|
x += d_in;
|
|
1355
1562
|
xt += d_out;
|
|
@@ -1360,6 +1567,7 @@ void RemapDimensionsTransform::check_identical(
|
|
|
1360
1567
|
const VectorTransform& other_in) const {
|
|
1361
1568
|
VectorTransform::check_identical(other_in);
|
|
1362
1569
|
auto other = dynamic_cast<const RemapDimensionsTransform*>(&other_in);
|
|
1363
|
-
|
|
1364
|
-
|
|
1570
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to RemapDimensionsTransform");
|
|
1571
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1572
|
+
other->map == map, "RemapDimensionsTransform maps must match");
|
|
1365
1573
|
}
|