faiss 0.5.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +38 -28
- data/ext/faiss/{index.cpp → index_rb.cpp} +64 -46
- data/ext/faiss/kmeans.cpp +10 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +5 -3
- data/ext/faiss/{utils.h → utils_rb.h} +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
/**
|
|
11
|
+
* @file simd_dispatch.h
|
|
12
|
+
* @brief Internal dispatch macros for SIMD level selection.
|
|
13
|
+
*
|
|
14
|
+
* This is a PRIVATE header - do not include in public APIs or user code.
|
|
15
|
+
* Only faiss internal .cpp files should include this header.
|
|
16
|
+
*
|
|
17
|
+
* For the public API (SIMDLevel enum, SIMDConfig class), use:
|
|
18
|
+
* #include <faiss/utils/simd_levels.h>
|
|
19
|
+
*/
|
|
20
|
+
|
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
|
22
|
+
#include <faiss/utils/simd_levels.h>
|
|
23
|
+
|
|
24
|
+
namespace faiss {
|
|
25
|
+
|
|
26
|
+
/*********************** x86 SIMD dispatch cases */
|
|
27
|
+
|
|
28
|
+
#ifdef COMPILE_SIMD_AVX2
|
|
29
|
+
#define DISPATCH_SIMDLevel_AVX2(f, ...) \
|
|
30
|
+
case SIMDLevel::AVX2: \
|
|
31
|
+
return f<SIMDLevel::AVX2>(__VA_ARGS__)
|
|
32
|
+
#else
|
|
33
|
+
#define DISPATCH_SIMDLevel_AVX2(f, ...)
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#ifdef COMPILE_SIMD_AVX512
|
|
37
|
+
#define DISPATCH_SIMDLevel_AVX512(f, ...) \
|
|
38
|
+
case SIMDLevel::AVX512: \
|
|
39
|
+
return f<SIMDLevel::AVX512>(__VA_ARGS__)
|
|
40
|
+
#else
|
|
41
|
+
#define DISPATCH_SIMDLevel_AVX512(f, ...)
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
#ifdef COMPILE_SIMD_AVX512_SPR
|
|
45
|
+
#define DISPATCH_SIMDLevel_AVX512_SPR(f, ...) \
|
|
46
|
+
case SIMDLevel::AVX512_SPR: \
|
|
47
|
+
return f<SIMDLevel::AVX512_SPR>(__VA_ARGS__)
|
|
48
|
+
#else
|
|
49
|
+
#define DISPATCH_SIMDLevel_AVX512_SPR(f, ...)
|
|
50
|
+
#endif
|
|
51
|
+
|
|
52
|
+
/*********************** ARM SIMD dispatch cases */
|
|
53
|
+
|
|
54
|
+
#ifdef COMPILE_SIMD_ARM_NEON
|
|
55
|
+
#define DISPATCH_SIMDLevel_ARM_NEON(f, ...) \
|
|
56
|
+
case SIMDLevel::ARM_NEON: \
|
|
57
|
+
return f<SIMDLevel::ARM_NEON>(__VA_ARGS__)
|
|
58
|
+
#else
|
|
59
|
+
#define DISPATCH_SIMDLevel_ARM_NEON(f, ...)
|
|
60
|
+
#endif
|
|
61
|
+
|
|
62
|
+
#ifdef COMPILE_SIMD_ARM_SVE
|
|
63
|
+
#define DISPATCH_SIMDLevel_ARM_SVE(f, ...) \
|
|
64
|
+
case SIMDLevel::ARM_SVE: \
|
|
65
|
+
return f<SIMDLevel::ARM_SVE>(__VA_ARGS__)
|
|
66
|
+
#else
|
|
67
|
+
#define DISPATCH_SIMDLevel_ARM_SVE(f, ...)
|
|
68
|
+
#endif
|
|
69
|
+
|
|
70
|
+
/*********************** Main dispatch macro */
|
|
71
|
+
|
|
72
|
+
#ifdef FAISS_ENABLE_DD
|
|
73
|
+
|
|
74
|
+
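// DD mode: runtime dispatch based on SIMDConfig::level.
// The macro expands to a switch in which every case returns, so nothing placed
// after it in the calling function is reached. SIMD levels that were not
// compiled in expand to nothing and therefore fall to the default case, which
// throws "Invalid SIMD level".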
|
|
75
|
+
#define DISPATCH_SIMDLevel(f, ...) \
|
|
76
|
+
switch (SIMDConfig::level) { \
|
|
77
|
+
case SIMDLevel::NONE: \
|
|
78
|
+
return f<SIMDLevel::NONE>(__VA_ARGS__); \
|
|
79
|
+
DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \
|
|
80
|
+
DISPATCH_SIMDLevel_AVX512(f, __VA_ARGS__); \
|
|
81
|
+
DISPATCH_SIMDLevel_AVX512_SPR(f, __VA_ARGS__); \
|
|
82
|
+
DISPATCH_SIMDLevel_ARM_NEON(f, __VA_ARGS__); \
|
|
83
|
+
DISPATCH_SIMDLevel_ARM_SVE(f, __VA_ARGS__); \
|
|
84
|
+
default: \
|
|
85
|
+
FAISS_THROW_MSG("Invalid SIMD level"); \
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
#else // Static mode
|
|
89
|
+
|
|
90
|
+
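// Static mode: direct call to compiled-in SIMD level (no runtime switch).
// If several COMPILE_SIMD_* flags are defined, the first match in this order
// is used: AVX512_SPR, AVX512, AVX2, ARM_SVE, ARM_NEON, then NONE as fallback.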
|
|
91
|
+
#if defined(COMPILE_SIMD_AVX512_SPR)
|
|
92
|
+
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX512_SPR>(__VA_ARGS__)
|
|
93
|
+
#elif defined(COMPILE_SIMD_AVX512)
|
|
94
|
+
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX512>(__VA_ARGS__)
|
|
95
|
+
#elif defined(COMPILE_SIMD_AVX2)
|
|
96
|
+
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX2>(__VA_ARGS__)
|
|
97
|
+
#elif defined(COMPILE_SIMD_ARM_SVE)
|
|
98
|
+
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::ARM_SVE>(__VA_ARGS__)
|
|
99
|
+
#elif defined(COMPILE_SIMD_ARM_NEON)
|
|
100
|
+
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::ARM_NEON>(__VA_ARGS__)
|
|
101
|
+
#else
|
|
102
|
+
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::NONE>(__VA_ARGS__)
|
|
103
|
+
#endif
|
|
104
|
+
|
|
105
|
+
#endif // FAISS_ENABLE_DD
|
|
106
|
+
|
|
107
|
+
/**
|
|
108
|
+
* Dispatch to a lambda with SIMDLevel as a compile-time constant.
|
|
109
|
+
*
|
|
110
|
+
* This function calls the provided templated lambda with the current
|
|
111
|
+
* runtime SIMD level (SIMDConfig::level in DD mode, or the level fixed at build time in static mode) as a compile-time template
|
|
112
|
+
* argument. This enables SIMD-specialized code paths while keeping the
|
|
113
|
+
* dispatch logic centralized.
|
|
114
|
+
*
|
|
115
|
+
* The key benefit is that the SIMD dispatch happens once, outside any loops,
|
|
116
|
+
* so the loop body runs with the optimal SIMD implementation without
|
|
117
|
+
* per-iteration dispatch overhead.
|
|
118
|
+
*
|
|
119
|
+
* Example with a loop (the dispatch happens once, not per iteration):
|
|
120
|
+
*
|
|
121
|
+
* std::vector<float> distances(n);
|
|
122
|
+
* with_simd_level([&]<SIMDLevel level>() {
|
|
123
|
+
* for (size_t i = 0; i < n; i++) {
|
|
124
|
+
* distances[i] = fvec_L2sqr<level>(query, vectors + i * d, d);
|
|
125
|
+
* }
|
|
126
|
+
* });
|
|
127
|
+
*
|
|
128
|
+
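* The lambda must be a generic lambda with an explicit SIMDLevel template parameter and no call arguments (C++20 `[&]<SIMDLevel level>() { ... }` syntax, as in the example above).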
|
|
129
|
+
*
|
|
130
|
+
* @param action A generic lambda with signature `template<SIMDLevel> T
|
|
131
|
+
* operator()()`
|
|
132
|
+
* @return The return value of the lambda
|
|
133
|
+
*/
|
|
134
|
+
// `action` is only ever invoked as an lvalue below; the forwarding reference
// just lets callers pass a temporary lambda.
template <typename LambdaType>
|
|
135
|
+
inline auto with_simd_level(LambdaType&& action) {
|
|
136
|
+
// DISPATCH_SIMDLevel expands to `return action.template operator()<Level>();`
// (wrapped in a switch over SIMDConfig::level in DD mode), so this single line
// is the whole function body and determines the deduced return type.
DISPATCH_SIMDLevel(action.template operator());
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
} // namespace faiss
|
|
@@ -126,8 +126,8 @@ struct StoreResultHandler : SIMDResultHandler {
|
|
|
126
126
|
|
|
127
127
|
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
128
128
|
size_t ofs = (q + i0) * ld + j0 + b * 32;
|
|
129
|
-
d0.
|
|
130
|
-
d1.
|
|
129
|
+
d0.storeu(data + ofs);
|
|
130
|
+
d1.storeu(data + ofs + 16);
|
|
131
131
|
}
|
|
132
132
|
|
|
133
133
|
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
|
@@ -406,10 +406,10 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
|
|
|
406
406
|
auto real_idx = this->adjust_id(b, j);
|
|
407
407
|
lt_mask -= 1 << j;
|
|
408
408
|
if (this->sel->is_member(real_idx)) {
|
|
409
|
-
T
|
|
410
|
-
if (C::cmp(heap_dis[0],
|
|
409
|
+
T dis_for_j = d32tab[j];
|
|
410
|
+
if (C::cmp(heap_dis[0], dis_for_j)) {
|
|
411
411
|
heap_replace_top<C>(
|
|
412
|
-
k, heap_dis, heap_ids,
|
|
412
|
+
k, heap_dis, heap_ids, dis_for_j, real_idx);
|
|
413
413
|
nup++;
|
|
414
414
|
}
|
|
415
415
|
}
|
|
@@ -419,10 +419,10 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
|
|
|
419
419
|
// find first non-zero
|
|
420
420
|
int j = __builtin_ctz(lt_mask);
|
|
421
421
|
lt_mask -= 1 << j;
|
|
422
|
-
T
|
|
423
|
-
if (C::cmp(heap_dis[0],
|
|
422
|
+
T dis_for_j = d32tab[j];
|
|
423
|
+
if (C::cmp(heap_dis[0], dis_for_j)) {
|
|
424
424
|
int64_t idx = this->adjust_id(b, j);
|
|
425
|
-
heap_replace_top<C>(k, heap_dis, heap_ids,
|
|
425
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis_for_j, idx);
|
|
426
426
|
nup++;
|
|
427
427
|
}
|
|
428
428
|
}
|
|
@@ -524,8 +524,8 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
|
524
524
|
auto real_idx = this->adjust_id(b, j);
|
|
525
525
|
lt_mask -= 1 << j;
|
|
526
526
|
if (this->sel->is_member(real_idx)) {
|
|
527
|
-
T
|
|
528
|
-
res.add(
|
|
527
|
+
T dis_for_j = d32tab[j];
|
|
528
|
+
res.add(dis_for_j, real_idx);
|
|
529
529
|
}
|
|
530
530
|
}
|
|
531
531
|
} else {
|
|
@@ -533,8 +533,8 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
|
533
533
|
// find first non-zero
|
|
534
534
|
int j = __builtin_ctz(lt_mask);
|
|
535
535
|
lt_mask -= 1 << j;
|
|
536
|
-
T
|
|
537
|
-
res.add(
|
|
536
|
+
T dis_for_j = d32tab[j];
|
|
537
|
+
res.add(dis_for_j, this->adjust_id(b, j));
|
|
538
538
|
}
|
|
539
539
|
}
|
|
540
540
|
}
|
|
@@ -761,12 +761,12 @@ void dispatch_SIMDResultHandler_fixedCW(
|
|
|
761
761
|
SIMDResultHandler& res,
|
|
762
762
|
Consumer& consumer,
|
|
763
763
|
Types... args) {
|
|
764
|
-
if (auto
|
|
765
|
-
consumer.template f<SingleResultHandler<C, W>>(*
|
|
766
|
-
} else if (auto
|
|
767
|
-
consumer.template f<HeapHandler<C, W>>(*
|
|
768
|
-
} else if (auto
|
|
769
|
-
consumer.template f<ReservoirHandler<C, W>>(*
|
|
764
|
+
if (auto resh_sh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
|
|
765
|
+
consumer.template f<SingleResultHandler<C, W>>(*resh_sh, args...);
|
|
766
|
+
} else if (auto resh_hh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
|
|
767
|
+
consumer.template f<HeapHandler<C, W>>(*resh_hh, args...);
|
|
768
|
+
} else if (auto resh_rh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
|
|
769
|
+
consumer.template f<ReservoirHandler<C, W>>(*resh_rh, args...);
|
|
770
770
|
} else { // generic handler -- will not be inlined
|
|
771
771
|
FAISS_THROW_IF_NOT_FMT(
|
|
772
772
|
simd_result_handlers_accept_virtual,
|
|
@@ -220,6 +220,9 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
|
|
|
220
220
|
if (match("RR([0-9]+)?")) {
|
|
221
221
|
return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
|
|
222
222
|
}
|
|
223
|
+
if (match("HR([0-9]+)?")) {
|
|
224
|
+
return new HadamardRotation(d, mres_to_int(sm[1], 12345));
|
|
225
|
+
}
|
|
223
226
|
if (match("ITQ([0-9]+)?")) {
|
|
224
227
|
return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
|
|
225
228
|
}
|
|
@@ -585,7 +588,7 @@ SVSStorageKind parse_lvq(const std::string& lvq_string) {
|
|
|
585
588
|
if (lvq_string == "LVQ4x8") {
|
|
586
589
|
return SVSStorageKind::SVS_LVQ4x8;
|
|
587
590
|
}
|
|
588
|
-
FAISS_ASSERT(
|
|
591
|
+
FAISS_ASSERT(false && "not supported SVS LVQ level");
|
|
589
592
|
}
|
|
590
593
|
|
|
591
594
|
SVSStorageKind parse_leanvec(const std::string& leanvec_string) {
|
|
@@ -598,7 +601,7 @@ SVSStorageKind parse_leanvec(const std::string& leanvec_string) {
|
|
|
598
601
|
if (leanvec_string == "LeanVec8x8") {
|
|
599
602
|
return SVSStorageKind::SVS_LeanVec8x8;
|
|
600
603
|
}
|
|
601
|
-
FAISS_ASSERT(
|
|
604
|
+
FAISS_ASSERT(false && "not supported SVS Leanvec level");
|
|
602
605
|
}
|
|
603
606
|
|
|
604
607
|
Index* parse_svs_datatype(
|
|
@@ -610,43 +613,49 @@ Index* parse_svs_datatype(
|
|
|
610
613
|
std::smatch sm;
|
|
611
614
|
|
|
612
615
|
if (datatype_string.empty()) {
|
|
613
|
-
if (index_type == "Vamana")
|
|
616
|
+
if (index_type == "Vamana") {
|
|
614
617
|
return new IndexSVSVamana(d, std::stoul(arg_string), mt);
|
|
615
|
-
|
|
618
|
+
}
|
|
619
|
+
if (index_type == "Flat") {
|
|
616
620
|
return new IndexSVSFlat(d, mt);
|
|
617
|
-
|
|
621
|
+
}
|
|
622
|
+
FAISS_ASSERT(false && "Unspported SVS index type");
|
|
618
623
|
}
|
|
619
624
|
if (re_match(datatype_string, "FP16", sm)) {
|
|
620
|
-
if (index_type == "Vamana")
|
|
625
|
+
if (index_type == "Vamana") {
|
|
621
626
|
return new IndexSVSVamana(
|
|
622
627
|
d, std::stoul(arg_string), mt, SVSStorageKind::SVS_FP16);
|
|
623
|
-
|
|
628
|
+
}
|
|
629
|
+
FAISS_ASSERT(false && "Unspported SVS index type for Float16");
|
|
624
630
|
}
|
|
625
631
|
if (re_match(datatype_string, "SQI8", sm)) {
|
|
626
|
-
if (index_type == "Vamana")
|
|
632
|
+
if (index_type == "Vamana") {
|
|
627
633
|
return new IndexSVSVamana(
|
|
628
634
|
d, std::stoul(arg_string), mt, SVSStorageKind::SVS_SQI8);
|
|
629
|
-
|
|
635
|
+
}
|
|
636
|
+
FAISS_ASSERT(false && "Unspported SVS index type for SQI8");
|
|
630
637
|
}
|
|
631
638
|
if (re_match(datatype_string, "(LVQ[0-9]+x[0-9]+)", sm)) {
|
|
632
|
-
if (index_type == "Vamana")
|
|
639
|
+
if (index_type == "Vamana") {
|
|
633
640
|
return new IndexSVSVamanaLVQ(
|
|
634
641
|
d, std::stoul(arg_string), mt, parse_lvq(sm[0].str()));
|
|
635
|
-
|
|
642
|
+
}
|
|
643
|
+
FAISS_ASSERT(false && "Unspported SVS index type for LVQ");
|
|
636
644
|
}
|
|
637
645
|
if (re_match(datatype_string, "(LeanVec[0-9]+x[0-9]+)(_[0-9]+)?", sm)) {
|
|
638
646
|
std::string leanvec_d_string =
|
|
639
647
|
sm[2].length() > 0 ? sm[2].str().substr(1) : "0";
|
|
640
|
-
int leanvec_d = std::stoul(leanvec_d_string);
|
|
648
|
+
int leanvec_d = static_cast<int>(std::stoul(leanvec_d_string));
|
|
641
649
|
|
|
642
|
-
if (index_type == "Vamana")
|
|
650
|
+
if (index_type == "Vamana") {
|
|
643
651
|
return new IndexSVSVamanaLeanVec(
|
|
644
652
|
d,
|
|
645
653
|
std::stoul(arg_string),
|
|
646
654
|
mt,
|
|
647
655
|
leanvec_d,
|
|
648
656
|
parse_leanvec(sm[1].str()));
|
|
649
|
-
|
|
657
|
+
}
|
|
658
|
+
FAISS_ASSERT(false && "Unspported SVS index type for LeanVec");
|
|
650
659
|
}
|
|
651
660
|
return nullptr;
|
|
652
661
|
}
|
|
@@ -659,7 +668,6 @@ Index* parse_IndexSVS(const std::string& code_string, int d, MetricType mt) {
|
|
|
659
668
|
return parse_svs_datatype("Flat", "", datatype_string, d, mt);
|
|
660
669
|
}
|
|
661
670
|
if (re_match(code_string, "Vamana([0-9]+)(,.+)?", sm)) {
|
|
662
|
-
Index* index{nullptr};
|
|
663
671
|
std::string degree_string = sm[1].str();
|
|
664
672
|
std::string datatype_string =
|
|
665
673
|
sm[2].length() > 0 ? sm[2].str().substr(1) : "";
|
|
@@ -667,7 +675,7 @@ Index* parse_IndexSVS(const std::string& code_string, int d, MetricType mt) {
|
|
|
667
675
|
"Vamana", degree_string, datatype_string, d, mt);
|
|
668
676
|
}
|
|
669
677
|
if (re_match(code_string, "IVF([0-9]+)(,.+)?", sm)) {
|
|
670
|
-
FAISS_ASSERT(
|
|
678
|
+
FAISS_ASSERT(false && "Unspported SVS index type");
|
|
671
679
|
}
|
|
672
680
|
return nullptr;
|
|
673
681
|
}
|
|
@@ -703,6 +711,17 @@ Index* parse_other_indexes(
|
|
|
703
711
|
}
|
|
704
712
|
}
|
|
705
713
|
|
|
714
|
+
// IndexFlatIPPanorama
|
|
715
|
+
if (match("FlatIPPanorama([0-9]+)(_[0-9]+)?")) {
|
|
716
|
+
FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT);
|
|
717
|
+
int nlevels = std::stoi(sm[1].str());
|
|
718
|
+
if (sm[2].length() == 0) {
|
|
719
|
+
return new IndexFlatIPPanorama(d, nlevels);
|
|
720
|
+
}
|
|
721
|
+
int batch_size = std::stoi(sm[2].str().substr(1));
|
|
722
|
+
return new IndexFlatIPPanorama(d, nlevels, (size_t)batch_size);
|
|
723
|
+
}
|
|
724
|
+
|
|
706
725
|
// IndexLSH
|
|
707
726
|
if (match("LSH([0-9]*)(r?)(t?)")) {
|
|
708
727
|
int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
|
|
@@ -11,13 +11,17 @@
|
|
|
11
11
|
#define FAISS_INDEX_IO_H
|
|
12
12
|
|
|
13
13
|
#include <cstdio>
|
|
14
|
+
#include <memory>
|
|
14
15
|
|
|
15
16
|
/** I/O functions can read/write to a filename, a file handle or to an
|
|
16
17
|
* object that abstracts the medium.
|
|
17
18
|
*
|
|
18
|
-
* The read functions
|
|
19
|
-
*
|
|
20
|
-
*
|
|
19
|
+
* The read functions come in two forms:
|
|
20
|
+
* - read_*_up() returns a std::unique_ptr that owns the result.
|
|
21
|
+
* - read_*() returns a raw pointer for backward compatibility.
|
|
22
|
+
* The caller is responsible for deleting the returned object.
|
|
23
|
+
*
|
|
24
|
+
* All references within these objects are owned by the object.
|
|
21
25
|
*/
|
|
22
26
|
|
|
23
27
|
namespace faiss {
|
|
@@ -68,25 +72,47 @@ Index* read_index(const char* fname, int io_flags = 0);
|
|
|
68
72
|
Index* read_index(FILE* f, int io_flags = 0);
|
|
69
73
|
Index* read_index(IOReader* reader, int io_flags = 0);
|
|
70
74
|
|
|
75
|
+
std::unique_ptr<Index> read_index_up(const char* fname, int io_flags = 0);
|
|
76
|
+
std::unique_ptr<Index> read_index_up(FILE* f, int io_flags = 0);
|
|
77
|
+
std::unique_ptr<Index> read_index_up(IOReader* reader, int io_flags = 0);
|
|
78
|
+
|
|
71
79
|
IndexBinary* read_index_binary(const char* fname, int io_flags = 0);
|
|
72
80
|
IndexBinary* read_index_binary(FILE* f, int io_flags = 0);
|
|
73
81
|
IndexBinary* read_index_binary(IOReader* reader, int io_flags = 0);
|
|
74
82
|
|
|
83
|
+
std::unique_ptr<IndexBinary> read_index_binary_up(
|
|
84
|
+
const char* fname,
|
|
85
|
+
int io_flags = 0);
|
|
86
|
+
std::unique_ptr<IndexBinary> read_index_binary_up(FILE* f, int io_flags = 0);
|
|
87
|
+
std::unique_ptr<IndexBinary> read_index_binary_up(
|
|
88
|
+
IOReader* reader,
|
|
89
|
+
int io_flags = 0);
|
|
90
|
+
|
|
75
91
|
void write_VectorTransform(const VectorTransform* vt, const char* fname);
|
|
76
92
|
void write_VectorTransform(const VectorTransform* vt, IOWriter* f);
|
|
77
93
|
|
|
78
94
|
VectorTransform* read_VectorTransform(const char* fname);
|
|
79
95
|
VectorTransform* read_VectorTransform(IOReader* f);
|
|
80
96
|
|
|
97
|
+
std::unique_ptr<VectorTransform> read_VectorTransform_up(const char* fname);
|
|
98
|
+
std::unique_ptr<VectorTransform> read_VectorTransform_up(IOReader* f);
|
|
99
|
+
|
|
81
100
|
ProductQuantizer* read_ProductQuantizer(const char* fname);
|
|
82
101
|
ProductQuantizer* read_ProductQuantizer(IOReader* reader);
|
|
83
102
|
|
|
103
|
+
std::unique_ptr<ProductQuantizer> read_ProductQuantizer_up(const char* fname);
|
|
104
|
+
std::unique_ptr<ProductQuantizer> read_ProductQuantizer_up(IOReader* reader);
|
|
105
|
+
|
|
84
106
|
void write_ProductQuantizer(const ProductQuantizer* pq, const char* fname);
|
|
85
107
|
void write_ProductQuantizer(const ProductQuantizer* pq, IOWriter* f);
|
|
86
108
|
|
|
87
109
|
void write_InvertedLists(const InvertedLists* ils, IOWriter* f);
|
|
88
110
|
InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0);
|
|
89
111
|
|
|
112
|
+
std::unique_ptr<InvertedLists> read_InvertedLists_up(
|
|
113
|
+
IOReader* reader,
|
|
114
|
+
int io_flags = 0);
|
|
115
|
+
|
|
90
116
|
} // namespace faiss
|
|
91
117
|
|
|
92
118
|
#endif
|
|
@@ -7,6 +7,8 @@
|
|
|
7
7
|
|
|
8
8
|
#include <faiss/invlists/BlockInvertedLists.h>
|
|
9
9
|
|
|
10
|
+
#include <memory>
|
|
11
|
+
|
|
10
12
|
#include <faiss/impl/CodePacker.h>
|
|
11
13
|
#include <faiss/impl/FaissAssert.h>
|
|
12
14
|
#include <faiss/impl/IDSelector.h>
|
|
@@ -81,7 +83,7 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
|
|
|
81
83
|
|
|
82
84
|
size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
|
|
83
85
|
idx_t nremove = 0;
|
|
84
|
-
#pragma omp parallel for
|
|
86
|
+
#pragma omp parallel for reduction(+ : nremove)
|
|
85
87
|
for (idx_t i = 0; i < nlist; i++) {
|
|
86
88
|
std::vector<uint8_t> buffer(packer->code_size);
|
|
87
89
|
idx_t l = ids[i].size(), j = 0;
|
|
@@ -95,8 +97,9 @@ size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
|
|
|
95
97
|
j++;
|
|
96
98
|
}
|
|
97
99
|
}
|
|
100
|
+
idx_t orig_size = ids[i].size();
|
|
98
101
|
resize(i, l);
|
|
99
|
-
nremove +=
|
|
102
|
+
nremove += orig_size - l;
|
|
100
103
|
}
|
|
101
104
|
|
|
102
105
|
return nremove;
|
|
@@ -160,7 +163,7 @@ void BlockInvertedListsIOHook::write(const InvertedLists* ils_in, IOWriter* f)
|
|
|
160
163
|
|
|
161
164
|
InvertedLists* BlockInvertedListsIOHook::read(IOReader* f, int /* io_flags */)
|
|
162
165
|
const {
|
|
163
|
-
|
|
166
|
+
auto il = std::make_unique<BlockInvertedLists>();
|
|
164
167
|
READ1(il->nlist);
|
|
165
168
|
READ1(il->code_size);
|
|
166
169
|
READ1(il->n_per_block);
|
|
@@ -174,7 +177,7 @@ InvertedLists* BlockInvertedListsIOHook::read(IOReader* f, int /* io_flags */)
|
|
|
174
177
|
READVECTOR(il->codes[i]);
|
|
175
178
|
}
|
|
176
179
|
|
|
177
|
-
return il;
|
|
180
|
+
return il.release();
|
|
178
181
|
}
|
|
179
182
|
|
|
180
183
|
} // namespace faiss
|
|
@@ -45,18 +45,6 @@
|
|
|
45
45
|
// create svs_runtime as alias for svs::runtime::FAISS_SVS_RUNTIME_VERSION
|
|
46
46
|
SVS_RUNTIME_CREATE_API_ALIAS(svs_runtime, FAISS_SVS_RUNTIME_VERSION);
|
|
47
47
|
|
|
48
|
-
// SVS forward declarations
|
|
49
|
-
namespace svs {
|
|
50
|
-
namespace runtime {
|
|
51
|
-
inline namespace v0 {
|
|
52
|
-
struct FlatIndex;
|
|
53
|
-
struct VamanaIndex;
|
|
54
|
-
struct DynamicVamanaIndex;
|
|
55
|
-
struct LeanVecTrainingData;
|
|
56
|
-
} // namespace v0
|
|
57
|
-
} // namespace runtime
|
|
58
|
-
} // namespace svs
|
|
59
|
-
|
|
60
48
|
namespace faiss {
|
|
61
49
|
|
|
62
50
|
inline svs_runtime::MetricType to_svs_metric(faiss::MetricType metric) {
|
|
@@ -66,7 +54,7 @@ inline svs_runtime::MetricType to_svs_metric(faiss::MetricType metric) {
|
|
|
66
54
|
case METRIC_L2:
|
|
67
55
|
return svs_runtime::MetricType::L2;
|
|
68
56
|
default:
|
|
69
|
-
FAISS_ASSERT(
|
|
57
|
+
FAISS_ASSERT(false && "not supported SVS distance");
|
|
70
58
|
}
|
|
71
59
|
}
|
|
72
60
|
|
|
@@ -93,7 +81,8 @@ template <typename T, typename U, typename = void>
|
|
|
93
81
|
struct InputBufferConverter {
|
|
94
82
|
InputBufferConverter(std::span<const U> data = {}) : buffer(data.size()) {
|
|
95
83
|
FAISS_ASSERT(
|
|
96
|
-
|
|
84
|
+
false &&
|
|
85
|
+
"InputBufferConverter: there is no suitable user code for this type conversion");
|
|
97
86
|
std::transform(
|
|
98
87
|
data.begin(), data.end(), buffer.begin(), [](const U& val) {
|
|
99
88
|
return static_cast<T>(val);
|
|
@@ -118,8 +107,8 @@ struct InputBufferConverter {
|
|
|
118
107
|
std::vector<T> buffer;
|
|
119
108
|
};
|
|
120
109
|
|
|
121
|
-
// Specialization for reinterpret cast when types are integral and have
|
|
122
|
-
// size
|
|
110
|
+
// Specialization for reinterpret cast when types are integral and have
|
|
111
|
+
// the same size
|
|
123
112
|
template <typename T, typename U>
|
|
124
113
|
struct InputBufferConverter<
|
|
125
114
|
T,
|
|
@@ -153,7 +142,8 @@ struct OutputBufferConverter {
|
|
|
153
142
|
OutputBufferConverter(std::span<U> data = {})
|
|
154
143
|
: data_span(data), buffer(data.size()) {
|
|
155
144
|
FAISS_ASSERT(
|
|
156
|
-
|
|
145
|
+
false &&
|
|
146
|
+
"OutputBufferConverter: there is no suitable user code for this type conversion");
|
|
157
147
|
}
|
|
158
148
|
|
|
159
149
|
~OutputBufferConverter() {
|
|
@@ -176,8 +166,8 @@ struct OutputBufferConverter {
|
|
|
176
166
|
std::vector<T> buffer;
|
|
177
167
|
};
|
|
178
168
|
|
|
179
|
-
// Specialization for reinterpret cast when types are integral and have
|
|
180
|
-
// size
|
|
169
|
+
// Specialization for reinterpret cast when types are integral and have
|
|
170
|
+
// the same size
|
|
181
171
|
template <typename T, typename U>
|
|
182
172
|
struct OutputBufferConverter<
|
|
183
173
|
T,
|
|
@@ -27,6 +27,7 @@
|
|
|
27
27
|
#include <faiss/svs/IndexSVSFaissUtils.h>
|
|
28
28
|
|
|
29
29
|
#include <svs/runtime/api_defs.h>
|
|
30
|
+
#include <svs/runtime/dynamic_vamana_index.h>
|
|
30
31
|
|
|
31
32
|
#include <iostream>
|
|
32
33
|
|
|
@@ -71,7 +72,7 @@ inline svs_runtime::StorageKind to_svs_storage_kind(SVSStorageKind kind) {
|
|
|
71
72
|
case SVS_LeanVec8x8:
|
|
72
73
|
return svs_runtime::StorageKind::LeanVec8x8;
|
|
73
74
|
default:
|
|
74
|
-
FAISS_ASSERT(
|
|
75
|
+
FAISS_ASSERT(false && "not supported SVS storage kind");
|
|
75
76
|
}
|
|
76
77
|
}
|
|
77
78
|
|
|
@@ -66,6 +66,14 @@ void IndexSVSVamanaLeanVec::add(idx_t n, const float* x) {
|
|
|
66
66
|
}
|
|
67
67
|
|
|
68
68
|
void IndexSVSVamanaLeanVec::train(idx_t n, const float* x) {
|
|
69
|
+
train(n, x, 0, nullptr);
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
void IndexSVSVamanaLeanVec::train(
|
|
73
|
+
idx_t n,
|
|
74
|
+
const float* x,
|
|
75
|
+
idx_t n_train_q,
|
|
76
|
+
const float* queries) {
|
|
69
77
|
FAISS_THROW_IF_MSG(
|
|
70
78
|
training_data || impl, "Index already trained or contains data.");
|
|
71
79
|
|
|
@@ -74,7 +82,7 @@ void IndexSVSVamanaLeanVec::train(idx_t n, const float* x) {
|
|
|
74
82
|
"LVQ/LeanVec support not available on this platform or build");
|
|
75
83
|
|
|
76
84
|
auto status = svs_runtime::LeanVecTrainingData::build(
|
|
77
|
-
&training_data, d, n, x, leanvec_d);
|
|
85
|
+
&training_data, d, n, x, n_train_q, queries, leanvec_d);
|
|
78
86
|
if (!status.ok()) {
|
|
79
87
|
FAISS_THROW_MSG(status.message());
|
|
80
88
|
}
|
|
@@ -41,8 +41,17 @@ struct IndexSVSVamanaLeanVec : IndexSVSVamana {
|
|
|
41
41
|
|
|
42
42
|
void add(idx_t n, const float* x) override;
|
|
43
43
|
|
|
44
|
+
/* Default train assumes in-distribution data */
|
|
44
45
|
void train(idx_t n, const float* x) override;
|
|
45
46
|
|
|
47
|
+
/* Generic train with out-of-distribution parameters.
|
|
48
|
+
* Out-of-distribution (OOD) means database vectors and queries _can_ be
|
|
49
|
+
* sampled from different distributions (e.g., cross-modal). More details in
|
|
50
|
+
* the original publication, arXiv:2312.16335.
|
|
51
|
+
*/
|
|
52
|
+
void train(idx_t n, const float* x, idx_t n_train_q, const float* xq_train)
|
|
53
|
+
override;
|
|
54
|
+
|
|
46
55
|
void serialize_training_data(std::ostream& out) const;
|
|
47
56
|
void deserialize_training_data(std::istream& in);
|
|
48
57
|
|
|
@@ -254,4 +254,50 @@ INSTANTIATE(CMax, float);
|
|
|
254
254
|
INSTANTIATE(CMin, int32_t);
|
|
255
255
|
INSTANTIATE(CMax, int32_t);
|
|
256
256
|
|
|
257
|
+
/**********************************************************
|
|
258
|
+
* reorder_2_heaps
|
|
259
|
+
**********************************************************/
|
|
260
|
+
|
|
261
|
+
template <class C>
|
|
262
|
+
void reorder_2_heaps(
|
|
263
|
+
int64_t n,
|
|
264
|
+
int64_t k,
|
|
265
|
+
typename C::TI* __restrict labels,
|
|
266
|
+
float* __restrict distances,
|
|
267
|
+
int64_t k_base,
|
|
268
|
+
const typename C::TI* __restrict base_labels,
|
|
269
|
+
const float* __restrict base_distances) {
|
|
270
|
+
#pragma omp parallel for if (n > 1)
|
|
271
|
+
for (int64_t i = 0; i < n; i++) {
|
|
272
|
+
typename C::TI* idxo = labels + i * k;
|
|
273
|
+
float* diso = distances + i * k;
|
|
274
|
+
const typename C::TI* idxi = base_labels + i * k_base;
|
|
275
|
+
const float* disi = base_distances + i * k_base;
|
|
276
|
+
|
|
277
|
+
heap_heapify<C>(k, diso, idxo, disi, idxi, k);
|
|
278
|
+
if (k_base != k) { // add remaining elements
|
|
279
|
+
heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
|
|
280
|
+
}
|
|
281
|
+
heap_reorder<C>(k, diso, idxo);
|
|
282
|
+
}
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
template void reorder_2_heaps<CMax<float, int64_t>>(
|
|
286
|
+
int64_t n,
|
|
287
|
+
int64_t k,
|
|
288
|
+
int64_t* __restrict labels,
|
|
289
|
+
float* __restrict distances,
|
|
290
|
+
int64_t k_base,
|
|
291
|
+
const int64_t* __restrict base_labels,
|
|
292
|
+
const float* __restrict base_distances);
|
|
293
|
+
|
|
294
|
+
template void reorder_2_heaps<CMin<float, int64_t>>(
|
|
295
|
+
int64_t n,
|
|
296
|
+
int64_t k,
|
|
297
|
+
int64_t* __restrict labels,
|
|
298
|
+
float* __restrict distances,
|
|
299
|
+
int64_t k_base,
|
|
300
|
+
const int64_t* __restrict base_labels,
|
|
301
|
+
const float* __restrict base_distances);
|
|
302
|
+
|
|
257
303
|
} // namespace faiss
|