faiss 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +20 -2
|
@@ -9,18 +9,19 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/ScalarQuantizer.h>
|
|
11
11
|
|
|
12
|
-
#include <cstdio>
|
|
13
12
|
#include <algorithm>
|
|
13
|
+
#include <cstdio>
|
|
14
14
|
|
|
15
|
+
#include <faiss/impl/platform_macros.h>
|
|
15
16
|
#include <omp.h>
|
|
16
17
|
|
|
17
18
|
#ifdef __SSE__
|
|
18
19
|
#include <immintrin.h>
|
|
19
20
|
#endif
|
|
20
21
|
|
|
21
|
-
#include <faiss/utils/utils.h>
|
|
22
|
-
#include <faiss/impl/FaissAssert.h>
|
|
23
22
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
23
|
+
#include <faiss/impl/FaissAssert.h>
|
|
24
|
+
#include <faiss/utils/utils.h>
|
|
24
25
|
|
|
25
26
|
namespace faiss {
|
|
26
27
|
|
|
@@ -43,11 +44,11 @@ namespace faiss {
|
|
|
43
44
|
#ifdef __F16C__
|
|
44
45
|
#define USE_F16C
|
|
45
46
|
#else
|
|
46
|
-
#warning
|
|
47
|
+
#warning \
|
|
48
|
+
"Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well"
|
|
47
49
|
#endif
|
|
48
50
|
#endif
|
|
49
51
|
|
|
50
|
-
|
|
51
52
|
namespace {
|
|
52
53
|
|
|
53
54
|
typedef Index::idx_t idx_t;
|
|
@@ -55,7 +56,6 @@ typedef ScalarQuantizer::QuantizerType QuantizerType;
|
|
|
55
56
|
typedef ScalarQuantizer::RangeStat RangeStat;
|
|
56
57
|
using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
|
|
57
58
|
|
|
58
|
-
|
|
59
59
|
/*******************************************************************
|
|
60
60
|
* Codec: converts between values in [0, 1] and an index in a code
|
|
61
61
|
* array. The "i" parameter is the vector component index (not byte
|
|
@@ -63,108 +63,103 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
|
|
|
63
63
|
*/
|
|
64
64
|
|
|
65
65
|
struct Codec8bit {
|
|
66
|
-
|
|
67
|
-
static void encode_component (float x, uint8_t *code, int i) {
|
|
66
|
+
static void encode_component(float x, uint8_t* code, int i) {
|
|
68
67
|
code[i] = (int)(255 * x);
|
|
69
68
|
}
|
|
70
69
|
|
|
71
|
-
static float decode_component
|
|
70
|
+
static float decode_component(const uint8_t* code, int i) {
|
|
72
71
|
return (code[i] + 0.5f) / 255.0f;
|
|
73
72
|
}
|
|
74
73
|
|
|
75
74
|
#ifdef __AVX2__
|
|
76
|
-
static __m256 decode_8_components
|
|
75
|
+
static __m256 decode_8_components(const uint8_t* code, int i) {
|
|
77
76
|
uint64_t c8 = *(uint64_t*)(code + i);
|
|
78
|
-
__m128i c4lo = _mm_cvtepu8_epi32
|
|
79
|
-
__m128i c4hi = _mm_cvtepu8_epi32
|
|
77
|
+
__m128i c4lo = _mm_cvtepu8_epi32(_mm_set1_epi32(c8));
|
|
78
|
+
__m128i c4hi = _mm_cvtepu8_epi32(_mm_set1_epi32(c8 >> 32));
|
|
80
79
|
// __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
|
|
81
|
-
__m256i i8 = _mm256_castsi128_si256
|
|
82
|
-
i8 = _mm256_insertf128_si256
|
|
83
|
-
__m256 f8 = _mm256_cvtepi32_ps
|
|
84
|
-
__m256 half = _mm256_set1_ps
|
|
85
|
-
f8
|
|
86
|
-
__m256 one_255 = _mm256_set1_ps
|
|
87
|
-
return f8
|
|
80
|
+
__m256i i8 = _mm256_castsi128_si256(c4lo);
|
|
81
|
+
i8 = _mm256_insertf128_si256(i8, c4hi, 1);
|
|
82
|
+
__m256 f8 = _mm256_cvtepi32_ps(i8);
|
|
83
|
+
__m256 half = _mm256_set1_ps(0.5f);
|
|
84
|
+
f8 = _mm256_add_ps(f8, half);
|
|
85
|
+
__m256 one_255 = _mm256_set1_ps(1.f / 255.f);
|
|
86
|
+
return _mm256_mul_ps(f8, one_255);
|
|
88
87
|
}
|
|
89
88
|
#endif
|
|
90
89
|
};
|
|
91
90
|
|
|
92
|
-
|
|
93
91
|
struct Codec4bit {
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
code [i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
|
|
92
|
+
static void encode_component(float x, uint8_t* code, int i) {
|
|
93
|
+
code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
|
|
97
94
|
}
|
|
98
95
|
|
|
99
|
-
static float decode_component
|
|
96
|
+
static float decode_component(const uint8_t* code, int i) {
|
|
100
97
|
return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
|
|
101
98
|
}
|
|
102
99
|
|
|
103
|
-
|
|
104
100
|
#ifdef __AVX2__
|
|
105
|
-
static __m256 decode_8_components
|
|
101
|
+
static __m256 decode_8_components(const uint8_t* code, int i) {
|
|
106
102
|
uint32_t c4 = *(uint32_t*)(code + (i >> 1));
|
|
107
103
|
uint32_t mask = 0x0f0f0f0f;
|
|
108
104
|
uint32_t c4ev = c4 & mask;
|
|
109
105
|
uint32_t c4od = (c4 >> 4) & mask;
|
|
110
106
|
|
|
111
107
|
// the 8 lower bytes of c8 contain the values
|
|
112
|
-
__m128i c8 =
|
|
113
|
-
|
|
114
|
-
__m128i c4lo = _mm_cvtepu8_epi32
|
|
115
|
-
__m128i c4hi = _mm_cvtepu8_epi32
|
|
116
|
-
__m256i i8 = _mm256_castsi128_si256
|
|
117
|
-
i8 = _mm256_insertf128_si256
|
|
118
|
-
__m256 f8 = _mm256_cvtepi32_ps
|
|
119
|
-
__m256 half = _mm256_set1_ps
|
|
120
|
-
f8
|
|
121
|
-
__m256 one_255 = _mm256_set1_ps
|
|
122
|
-
return f8
|
|
108
|
+
__m128i c8 =
|
|
109
|
+
_mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od));
|
|
110
|
+
__m128i c4lo = _mm_cvtepu8_epi32(c8);
|
|
111
|
+
__m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4));
|
|
112
|
+
__m256i i8 = _mm256_castsi128_si256(c4lo);
|
|
113
|
+
i8 = _mm256_insertf128_si256(i8, c4hi, 1);
|
|
114
|
+
__m256 f8 = _mm256_cvtepi32_ps(i8);
|
|
115
|
+
__m256 half = _mm256_set1_ps(0.5f);
|
|
116
|
+
f8 = _mm256_add_ps(f8, half);
|
|
117
|
+
__m256 one_255 = _mm256_set1_ps(1.f / 15.f);
|
|
118
|
+
return _mm256_mul_ps(f8, one_255);
|
|
123
119
|
}
|
|
124
120
|
#endif
|
|
125
121
|
};
|
|
126
122
|
|
|
127
123
|
struct Codec6bit {
|
|
128
|
-
|
|
129
|
-
static void encode_component (float x, uint8_t *code, int i) {
|
|
124
|
+
static void encode_component(float x, uint8_t* code, int i) {
|
|
130
125
|
int bits = (int)(x * 63.0);
|
|
131
126
|
code += (i >> 2) * 3;
|
|
132
|
-
switch(i & 3) {
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
127
|
+
switch (i & 3) {
|
|
128
|
+
case 0:
|
|
129
|
+
code[0] |= bits;
|
|
130
|
+
break;
|
|
131
|
+
case 1:
|
|
132
|
+
code[0] |= bits << 6;
|
|
133
|
+
code[1] |= bits >> 2;
|
|
134
|
+
break;
|
|
135
|
+
case 2:
|
|
136
|
+
code[1] |= bits << 4;
|
|
137
|
+
code[2] |= bits >> 4;
|
|
138
|
+
break;
|
|
139
|
+
case 3:
|
|
140
|
+
code[2] |= bits << 2;
|
|
141
|
+
break;
|
|
147
142
|
}
|
|
148
143
|
}
|
|
149
144
|
|
|
150
|
-
static float decode_component
|
|
145
|
+
static float decode_component(const uint8_t* code, int i) {
|
|
151
146
|
uint8_t bits;
|
|
152
147
|
code += (i >> 2) * 3;
|
|
153
|
-
switch(i & 3) {
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
148
|
+
switch (i & 3) {
|
|
149
|
+
case 0:
|
|
150
|
+
bits = code[0] & 0x3f;
|
|
151
|
+
break;
|
|
152
|
+
case 1:
|
|
153
|
+
bits = code[0] >> 6;
|
|
154
|
+
bits |= (code[1] & 0xf) << 2;
|
|
155
|
+
break;
|
|
156
|
+
case 2:
|
|
157
|
+
bits = code[1] >> 4;
|
|
158
|
+
bits |= (code[2] & 3) << 4;
|
|
159
|
+
break;
|
|
160
|
+
case 3:
|
|
161
|
+
bits = code[2] >> 2;
|
|
162
|
+
break;
|
|
168
163
|
}
|
|
169
164
|
return (bits + 0.5f) / 63.0f;
|
|
170
165
|
}
|
|
@@ -173,12 +168,14 @@ struct Codec6bit {
|
|
|
173
168
|
|
|
174
169
|
/* Load 6 bytes that represent 8 6-bit values, return them as a
|
|
175
170
|
* 8*32 bit vector register */
|
|
176
|
-
static __m256i load6
|
|
177
|
-
const __m128i perm = _mm_set_epi8(
|
|
171
|
+
static __m256i load6(const uint16_t* code16) {
|
|
172
|
+
const __m128i perm = _mm_set_epi8(
|
|
173
|
+
-1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
|
|
178
174
|
const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
|
|
179
175
|
|
|
180
176
|
// load 6 bytes
|
|
181
|
-
__m128i c1 =
|
|
177
|
+
__m128i c1 =
|
|
178
|
+
_mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);
|
|
182
179
|
|
|
183
180
|
// put in 8 * 32 bits
|
|
184
181
|
__m128i c2 = _mm_shuffle_epi8(c1, perm);
|
|
@@ -190,37 +187,33 @@ struct Codec6bit {
|
|
|
190
187
|
return c5;
|
|
191
188
|
}
|
|
192
189
|
|
|
193
|
-
static __m256 decode_8_components
|
|
194
|
-
__m256i i8 = load6
|
|
195
|
-
__m256 f8 = _mm256_cvtepi32_ps
|
|
190
|
+
static __m256 decode_8_components(const uint8_t* code, int i) {
|
|
191
|
+
__m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
|
|
192
|
+
__m256 f8 = _mm256_cvtepi32_ps(i8);
|
|
196
193
|
// this could also be done with bit manipulations but it is
|
|
197
194
|
// not obviously faster
|
|
198
|
-
__m256 half = _mm256_set1_ps
|
|
199
|
-
f8
|
|
200
|
-
__m256 one_63 = _mm256_set1_ps
|
|
201
|
-
return f8
|
|
195
|
+
__m256 half = _mm256_set1_ps(0.5f);
|
|
196
|
+
f8 = _mm256_add_ps(f8, half);
|
|
197
|
+
__m256 one_63 = _mm256_set1_ps(1.f / 63.f);
|
|
198
|
+
return _mm256_mul_ps(f8, one_63);
|
|
202
199
|
}
|
|
203
200
|
|
|
204
201
|
#endif
|
|
205
202
|
};
|
|
206
203
|
|
|
207
|
-
|
|
208
|
-
|
|
209
204
|
#ifdef USE_F16C
|
|
210
205
|
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
return _mm_cvtsi128_si32 (xi) & 0xffff;
|
|
206
|
+
uint16_t encode_fp16(float x) {
|
|
207
|
+
__m128 xf = _mm_set1_ps(x);
|
|
208
|
+
__m128i xi =
|
|
209
|
+
_mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
|
210
|
+
return _mm_cvtsi128_si32(xi) & 0xffff;
|
|
217
211
|
}
|
|
218
212
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
return _mm_cvtss_f32 (xf);
|
|
213
|
+
float decode_fp16(uint16_t x) {
|
|
214
|
+
__m128i xi = _mm_set1_epi16(x);
|
|
215
|
+
__m128 xf = _mm_cvtph_ps(xi);
|
|
216
|
+
return _mm_cvtss_f32(xf);
|
|
224
217
|
}
|
|
225
218
|
|
|
226
219
|
#else
|
|
@@ -228,19 +221,17 @@ float decode_fp16 (uint16_t x) {
|
|
|
228
221
|
// non-intrinsic FP16 <-> FP32 code adapted from
|
|
229
222
|
// https://github.com/ispc/ispc/blob/master/stdlib.ispc
|
|
230
223
|
|
|
231
|
-
float floatbits
|
|
232
|
-
void
|
|
224
|
+
float floatbits(uint32_t x) {
|
|
225
|
+
void* xptr = &x;
|
|
233
226
|
return *(float*)xptr;
|
|
234
227
|
}
|
|
235
228
|
|
|
236
|
-
uint32_t intbits
|
|
237
|
-
void
|
|
229
|
+
uint32_t intbits(float f) {
|
|
230
|
+
void* fptr = &f;
|
|
238
231
|
return *(uint32_t*)fptr;
|
|
239
232
|
}
|
|
240
233
|
|
|
241
|
-
|
|
242
|
-
uint16_t encode_fp16 (float f) {
|
|
243
|
-
|
|
234
|
+
uint16_t encode_fp16(float f) {
|
|
244
235
|
// via Fabian "ryg" Giesen.
|
|
245
236
|
// https://gist.github.com/2156668
|
|
246
237
|
uint32_t sign_mask = 0x80000000u;
|
|
@@ -297,20 +288,19 @@ uint16_t encode_fp16 (float f) {
|
|
|
297
288
|
return (o | (sign >> 16));
|
|
298
289
|
}
|
|
299
290
|
|
|
300
|
-
float decode_fp16
|
|
301
|
-
|
|
291
|
+
float decode_fp16(uint16_t h) {
|
|
302
292
|
// https://gist.github.com/2144712
|
|
303
293
|
// Fabian "ryg" Giesen.
|
|
304
294
|
|
|
305
295
|
const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
|
|
306
296
|
|
|
307
|
-
int32_t o = ((int32_t)(h & 0x7fffu)) << 13;
|
|
308
|
-
int32_t exp = shifted_exp & o;
|
|
309
|
-
o += (int32_t)(127 - 15) << 23;
|
|
297
|
+
int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
|
|
298
|
+
int32_t exp = shifted_exp & o; // just the exponent
|
|
299
|
+
o += (int32_t)(127 - 15) << 23; // exponent adjust
|
|
310
300
|
|
|
311
301
|
int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
|
|
312
|
-
int32_t zerodenorm_val =
|
|
313
|
-
|
|
302
|
+
int32_t zerodenorm_val =
|
|
303
|
+
intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
|
|
314
304
|
int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
|
|
315
305
|
|
|
316
306
|
int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
|
|
@@ -319,30 +309,21 @@ float decode_fp16 (uint16_t h) {
|
|
|
319
309
|
|
|
320
310
|
#endif
|
|
321
311
|
|
|
322
|
-
|
|
323
|
-
|
|
324
312
|
/*******************************************************************
|
|
325
313
|
* Quantizer: normalizes scalar vector components, then passes them
|
|
326
314
|
* through a codec
|
|
327
315
|
*******************************************************************/
|
|
328
316
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
template<class Codec, bool uniform, int SIMD>
|
|
317
|
+
template <class Codec, bool uniform, int SIMD>
|
|
334
318
|
struct QuantizerTemplate {};
|
|
335
319
|
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
|
|
320
|
+
template <class Codec>
|
|
321
|
+
struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::Quantizer {
|
|
339
322
|
const size_t d;
|
|
340
323
|
const float vmin, vdiff;
|
|
341
324
|
|
|
342
|
-
QuantizerTemplate(size_t d, const std::vector<float
|
|
343
|
-
|
|
344
|
-
{
|
|
345
|
-
}
|
|
325
|
+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
|
|
326
|
+
: d(d), vmin(trained[0]), vdiff(trained[1]) {}
|
|
346
327
|
|
|
347
328
|
void encode_vector(const float* x, uint8_t* code) const final {
|
|
348
329
|
for (size_t i = 0; i < d; i++) {
|
|
@@ -367,43 +348,36 @@ struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
|
|
|
367
348
|
}
|
|
368
349
|
}
|
|
369
350
|
|
|
370
|
-
float reconstruct_component
|
|
371
|
-
|
|
372
|
-
float xi = Codec::decode_component (code, i);
|
|
351
|
+
float reconstruct_component(const uint8_t* code, int i) const {
|
|
352
|
+
float xi = Codec::decode_component(code, i);
|
|
373
353
|
return vmin + xi * vdiff;
|
|
374
354
|
}
|
|
375
|
-
|
|
376
355
|
};
|
|
377
356
|
|
|
378
|
-
|
|
379
|
-
|
|
380
357
|
#ifdef __AVX2__
|
|
381
358
|
|
|
382
|
-
template<class Codec>
|
|
383
|
-
struct QuantizerTemplate<Codec, true, 8
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
QuantizerTemplate<Codec, true, 1> (d, trained) {}
|
|
359
|
+
template <class Codec>
|
|
360
|
+
struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
|
|
361
|
+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
|
|
362
|
+
: QuantizerTemplate<Codec, true, 1>(d, trained) {}
|
|
387
363
|
|
|
388
|
-
__m256 reconstruct_8_components
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
364
|
+
__m256 reconstruct_8_components(const uint8_t* code, int i) const {
|
|
365
|
+
__m256 xi = Codec::decode_8_components(code, i);
|
|
366
|
+
return _mm256_add_ps(
|
|
367
|
+
_mm256_set1_ps(this->vmin),
|
|
368
|
+
_mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
|
|
392
369
|
}
|
|
393
|
-
|
|
394
370
|
};
|
|
395
371
|
|
|
396
372
|
#endif
|
|
397
373
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
template<class Codec>
|
|
401
|
-
struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
|
|
374
|
+
template <class Codec>
|
|
375
|
+
struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::Quantizer {
|
|
402
376
|
const size_t d;
|
|
403
377
|
const float *vmin, *vdiff;
|
|
404
378
|
|
|
405
|
-
QuantizerTemplate
|
|
406
|
-
|
|
379
|
+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
|
|
380
|
+
: d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
|
|
407
381
|
|
|
408
382
|
void encode_vector(const float* x, uint8_t* code) const final {
|
|
409
383
|
for (size_t i = 0; i < d; i++) {
|
|
@@ -428,30 +402,25 @@ struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
|
|
|
428
402
|
}
|
|
429
403
|
}
|
|
430
404
|
|
|
431
|
-
float reconstruct_component
|
|
432
|
-
|
|
433
|
-
float xi = Codec::decode_component (code, i);
|
|
405
|
+
float reconstruct_component(const uint8_t* code, int i) const {
|
|
406
|
+
float xi = Codec::decode_component(code, i);
|
|
434
407
|
return vmin[i] + xi * vdiff[i];
|
|
435
408
|
}
|
|
436
|
-
|
|
437
409
|
};
|
|
438
410
|
|
|
439
|
-
|
|
440
411
|
#ifdef __AVX2__
|
|
441
412
|
|
|
442
|
-
template<class Codec>
|
|
443
|
-
struct QuantizerTemplate<Codec, false, 8
|
|
413
|
+
template <class Codec>
|
|
414
|
+
struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
|
|
415
|
+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
|
|
416
|
+
: QuantizerTemplate<Codec, false, 1>(d, trained) {}
|
|
444
417
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
__m256 xi = Codec::decode_8_components (code, i);
|
|
451
|
-
return _mm256_loadu_ps (this->vmin + i) + xi * _mm256_loadu_ps (this->vdiff + i);
|
|
418
|
+
__m256 reconstruct_8_components(const uint8_t* code, int i) const {
|
|
419
|
+
__m256 xi = Codec::decode_8_components(code, i);
|
|
420
|
+
return _mm256_add_ps(
|
|
421
|
+
_mm256_loadu_ps(this->vmin + i),
|
|
422
|
+
_mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
|
|
452
423
|
}
|
|
453
|
-
|
|
454
|
-
|
|
455
424
|
};
|
|
456
425
|
|
|
457
426
|
#endif
|
|
@@ -460,15 +429,14 @@ struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
|
|
|
460
429
|
* FP16 quantizer
|
|
461
430
|
*******************************************************************/
|
|
462
431
|
|
|
463
|
-
template<int SIMDWIDTH>
|
|
432
|
+
template <int SIMDWIDTH>
|
|
464
433
|
struct QuantizerFP16 {};
|
|
465
434
|
|
|
466
|
-
template<>
|
|
467
|
-
struct QuantizerFP16<1
|
|
435
|
+
template <>
|
|
436
|
+
struct QuantizerFP16<1> : ScalarQuantizer::Quantizer {
|
|
468
437
|
const size_t d;
|
|
469
438
|
|
|
470
|
-
QuantizerFP16(size_t d, const std::vector<float
|
|
471
|
-
d(d) {}
|
|
439
|
+
QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
|
|
472
440
|
|
|
473
441
|
void encode_vector(const float* x, uint8_t* code) const final {
|
|
474
442
|
for (size_t i = 0; i < d; i++) {
|
|
@@ -482,27 +450,22 @@ struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
|
|
|
482
450
|
}
|
|
483
451
|
}
|
|
484
452
|
|
|
485
|
-
float reconstruct_component
|
|
486
|
-
{
|
|
453
|
+
float reconstruct_component(const uint8_t* code, int i) const {
|
|
487
454
|
return decode_fp16(((uint16_t*)code)[i]);
|
|
488
455
|
}
|
|
489
|
-
|
|
490
456
|
};
|
|
491
457
|
|
|
492
458
|
#ifdef USE_F16C
|
|
493
459
|
|
|
494
|
-
template<>
|
|
495
|
-
struct QuantizerFP16<8
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
QuantizerFP16<1> (d, trained) {}
|
|
460
|
+
template <>
|
|
461
|
+
struct QuantizerFP16<8> : QuantizerFP16<1> {
|
|
462
|
+
QuantizerFP16(size_t d, const std::vector<float>& trained)
|
|
463
|
+
: QuantizerFP16<1>(d, trained) {}
|
|
499
464
|
|
|
500
|
-
__m256 reconstruct_8_components
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
return _mm256_cvtph_ps (codei);
|
|
465
|
+
__m256 reconstruct_8_components(const uint8_t* code, int i) const {
|
|
466
|
+
__m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
|
|
467
|
+
return _mm256_cvtph_ps(codei);
|
|
504
468
|
}
|
|
505
|
-
|
|
506
469
|
};
|
|
507
470
|
|
|
508
471
|
#endif
|
|
@@ -511,16 +474,15 @@ struct QuantizerFP16<8>: QuantizerFP16<1> {
|
|
|
511
474
|
* 8bit_direct quantizer
|
|
512
475
|
*******************************************************************/
|
|
513
476
|
|
|
514
|
-
template<int SIMDWIDTH>
|
|
477
|
+
template <int SIMDWIDTH>
|
|
515
478
|
struct Quantizer8bitDirect {};
|
|
516
479
|
|
|
517
|
-
template<>
|
|
518
|
-
struct Quantizer8bitDirect<1
|
|
480
|
+
template <>
|
|
481
|
+
struct Quantizer8bitDirect<1> : ScalarQuantizer::Quantizer {
|
|
519
482
|
const size_t d;
|
|
520
483
|
|
|
521
|
-
Quantizer8bitDirect(size_t d, const std::vector<float
|
|
522
|
-
|
|
523
|
-
|
|
484
|
+
Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
|
|
485
|
+
: d(d) {}
|
|
524
486
|
|
|
525
487
|
void encode_vector(const float* x, uint8_t* code) const final {
|
|
526
488
|
for (size_t i = 0; i < d; i++) {
|
|
@@ -534,82 +496,83 @@ struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
|
|
|
534
496
|
}
|
|
535
497
|
}
|
|
536
498
|
|
|
537
|
-
float reconstruct_component
|
|
538
|
-
{
|
|
499
|
+
float reconstruct_component(const uint8_t* code, int i) const {
|
|
539
500
|
return code[i];
|
|
540
501
|
}
|
|
541
|
-
|
|
542
502
|
};
|
|
543
503
|
|
|
544
504
|
#ifdef __AVX2__
|
|
545
505
|
|
|
546
|
-
template<>
|
|
547
|
-
struct Quantizer8bitDirect<8
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
Quantizer8bitDirect<1> (d, trained) {}
|
|
506
|
+
template <>
|
|
507
|
+
struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
|
|
508
|
+
Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
|
|
509
|
+
: Quantizer8bitDirect<1>(d, trained) {}
|
|
551
510
|
|
|
552
|
-
__m256 reconstruct_8_components
|
|
553
|
-
{
|
|
511
|
+
__m256 reconstruct_8_components(const uint8_t* code, int i) const {
|
|
554
512
|
__m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
|
|
555
|
-
__m256i y8 = _mm256_cvtepu8_epi32
|
|
556
|
-
return _mm256_cvtepi32_ps
|
|
513
|
+
__m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
|
|
514
|
+
return _mm256_cvtepi32_ps(y8); // 8 * float32
|
|
557
515
|
}
|
|
558
|
-
|
|
559
516
|
};
|
|
560
517
|
|
|
561
518
|
#endif
|
|
562
519
|
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
{
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
520
|
+
template <int SIMDWIDTH>
|
|
521
|
+
ScalarQuantizer::Quantizer* select_quantizer_1(
|
|
522
|
+
QuantizerType qtype,
|
|
523
|
+
size_t d,
|
|
524
|
+
const std::vector<float>& trained) {
|
|
525
|
+
switch (qtype) {
|
|
526
|
+
case ScalarQuantizer::QT_8bit:
|
|
527
|
+
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
|
|
528
|
+
d, trained);
|
|
529
|
+
case ScalarQuantizer::QT_6bit:
|
|
530
|
+
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
|
|
531
|
+
d, trained);
|
|
532
|
+
case ScalarQuantizer::QT_4bit:
|
|
533
|
+
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
|
|
534
|
+
d, trained);
|
|
535
|
+
case ScalarQuantizer::QT_8bit_uniform:
|
|
536
|
+
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
|
|
537
|
+
d, trained);
|
|
538
|
+
case ScalarQuantizer::QT_4bit_uniform:
|
|
539
|
+
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
|
|
540
|
+
d, trained);
|
|
541
|
+
case ScalarQuantizer::QT_fp16:
|
|
542
|
+
return new QuantizerFP16<SIMDWIDTH>(d, trained);
|
|
543
|
+
case ScalarQuantizer::QT_8bit_direct:
|
|
544
|
+
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
|
|
545
|
+
}
|
|
546
|
+
FAISS_THROW_MSG("unknown qtype");
|
|
586
547
|
}
|
|
587
548
|
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
549
|
/*******************************************************************
|
|
592
550
|
* Quantizer range training
|
|
593
551
|
*/
|
|
594
552
|
|
|
595
|
-
static float sqr
|
|
553
|
+
static float sqr(float x) {
|
|
596
554
|
return x * x;
|
|
597
555
|
}
|
|
598
556
|
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
557
|
+
void train_Uniform(
|
|
558
|
+
RangeStat rs,
|
|
559
|
+
float rs_arg,
|
|
560
|
+
idx_t n,
|
|
561
|
+
int k,
|
|
562
|
+
const float* x,
|
|
563
|
+
std::vector<float>& trained) {
|
|
564
|
+
trained.resize(2);
|
|
565
|
+
float& vmin = trained[0];
|
|
566
|
+
float& vmax = trained[1];
|
|
607
567
|
|
|
608
568
|
if (rs == ScalarQuantizer::RS_minmax) {
|
|
609
|
-
vmin = HUGE_VAL;
|
|
569
|
+
vmin = HUGE_VAL;
|
|
570
|
+
vmax = -HUGE_VAL;
|
|
610
571
|
for (size_t i = 0; i < n; i++) {
|
|
611
|
-
if (x[i] < vmin)
|
|
612
|
-
|
|
572
|
+
if (x[i] < vmin)
|
|
573
|
+
vmin = x[i];
|
|
574
|
+
if (x[i] > vmax)
|
|
575
|
+
vmax = x[i];
|
|
613
576
|
}
|
|
614
577
|
float vexp = (vmax - vmin) * rs_arg;
|
|
615
578
|
vmin -= vexp;
|
|
@@ -624,16 +587,18 @@ void train_Uniform(RangeStat rs, float rs_arg,
|
|
|
624
587
|
float var = sum2 / n - mean * mean;
|
|
625
588
|
float std = var <= 0 ? 1.0 : sqrt(var);
|
|
626
589
|
|
|
627
|
-
vmin = mean - std * rs_arg
|
|
628
|
-
vmax = mean + std * rs_arg
|
|
590
|
+
vmin = mean - std * rs_arg;
|
|
591
|
+
vmax = mean + std * rs_arg;
|
|
629
592
|
} else if (rs == ScalarQuantizer::RS_quantiles) {
|
|
630
593
|
std::vector<float> x_copy(n);
|
|
631
594
|
memcpy(x_copy.data(), x, n * sizeof(*x));
|
|
632
595
|
// TODO just do a qucikselect
|
|
633
596
|
std::sort(x_copy.begin(), x_copy.end());
|
|
634
597
|
int o = int(rs_arg * n);
|
|
635
|
-
if (o < 0)
|
|
636
|
-
|
|
598
|
+
if (o < 0)
|
|
599
|
+
o = 0;
|
|
600
|
+
if (o > n - o)
|
|
601
|
+
o = n / 2;
|
|
637
602
|
vmin = x_copy[o];
|
|
638
603
|
vmax = x_copy[n - 1 - o];
|
|
639
604
|
|
|
@@ -643,8 +608,10 @@ void train_Uniform(RangeStat rs, float rs_arg,
|
|
|
643
608
|
{
|
|
644
609
|
vmin = HUGE_VAL, vmax = -HUGE_VAL;
|
|
645
610
|
for (size_t i = 0; i < n; i++) {
|
|
646
|
-
if (x[i] < vmin)
|
|
647
|
-
|
|
611
|
+
if (x[i] < vmin)
|
|
612
|
+
vmin = x[i];
|
|
613
|
+
if (x[i] > vmax)
|
|
614
|
+
vmax = x[i];
|
|
648
615
|
sx += x[i];
|
|
649
616
|
}
|
|
650
617
|
b = vmin;
|
|
@@ -659,62 +626,71 @@ void train_Uniform(RangeStat rs, float rs_arg,
|
|
|
659
626
|
|
|
660
627
|
for (idx_t i = 0; i < n; i++) {
|
|
661
628
|
float xi = x[i];
|
|
662
|
-
float ni = floor
|
|
663
|
-
if (ni < 0)
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
629
|
+
float ni = floor((xi - b) / a + 0.5);
|
|
630
|
+
if (ni < 0)
|
|
631
|
+
ni = 0;
|
|
632
|
+
if (ni >= k)
|
|
633
|
+
ni = k - 1;
|
|
634
|
+
err1 += sqr(xi - (ni * a + b));
|
|
635
|
+
sn += ni;
|
|
667
636
|
sn2 += ni * ni;
|
|
668
637
|
sxn += ni * xi;
|
|
669
638
|
}
|
|
670
639
|
|
|
671
640
|
if (err1 == last_err) {
|
|
672
|
-
iter_last_err
|
|
673
|
-
if (iter_last_err == 16)
|
|
641
|
+
iter_last_err++;
|
|
642
|
+
if (iter_last_err == 16)
|
|
643
|
+
break;
|
|
674
644
|
} else {
|
|
675
645
|
last_err = err1;
|
|
676
646
|
iter_last_err = 0;
|
|
677
647
|
}
|
|
678
648
|
|
|
679
|
-
float det = sqr
|
|
649
|
+
float det = sqr(sn) - sn2 * n;
|
|
680
650
|
|
|
681
651
|
b = (sn * sxn - sn2 * sx) / det;
|
|
682
652
|
a = (sn * sx - n * sxn) / det;
|
|
683
653
|
if (verbose) {
|
|
684
|
-
printf
|
|
654
|
+
printf("it %d, err1=%g \r", it, err1);
|
|
685
655
|
fflush(stdout);
|
|
686
656
|
}
|
|
687
657
|
}
|
|
688
|
-
if (verbose)
|
|
658
|
+
if (verbose)
|
|
659
|
+
printf("\n");
|
|
689
660
|
|
|
690
661
|
vmin = b;
|
|
691
662
|
vmax = b + a * (k - 1);
|
|
692
663
|
|
|
693
664
|
} else {
|
|
694
|
-
FAISS_THROW_MSG
|
|
665
|
+
FAISS_THROW_MSG("Invalid qtype");
|
|
695
666
|
}
|
|
696
667
|
vmax -= vmin;
|
|
697
668
|
}
|
|
698
669
|
|
|
699
|
-
void train_NonUniform(
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
670
|
+
void train_NonUniform(
|
|
671
|
+
RangeStat rs,
|
|
672
|
+
float rs_arg,
|
|
673
|
+
idx_t n,
|
|
674
|
+
int d,
|
|
675
|
+
int k,
|
|
676
|
+
const float* x,
|
|
677
|
+
std::vector<float>& trained) {
|
|
678
|
+
trained.resize(2 * d);
|
|
679
|
+
float* vmin = trained.data();
|
|
680
|
+
float* vmax = trained.data() + d;
|
|
707
681
|
if (rs == ScalarQuantizer::RS_minmax) {
|
|
708
|
-
memcpy
|
|
709
|
-
memcpy
|
|
682
|
+
memcpy(vmin, x, sizeof(*x) * d);
|
|
683
|
+
memcpy(vmax, x, sizeof(*x) * d);
|
|
710
684
|
for (size_t i = 1; i < n; i++) {
|
|
711
|
-
const float
|
|
685
|
+
const float* xi = x + i * d;
|
|
712
686
|
for (size_t j = 0; j < d; j++) {
|
|
713
|
-
if (xi[j] < vmin[j])
|
|
714
|
-
|
|
687
|
+
if (xi[j] < vmin[j])
|
|
688
|
+
vmin[j] = xi[j];
|
|
689
|
+
if (xi[j] > vmax[j])
|
|
690
|
+
vmax[j] = xi[j];
|
|
715
691
|
}
|
|
716
692
|
}
|
|
717
|
-
float
|
|
693
|
+
float* vdiff = vmax;
|
|
718
694
|
for (size_t j = 0; j < d; j++) {
|
|
719
695
|
float vexp = (vmax[j] - vmin[j]) * rs_arg;
|
|
720
696
|
vmin[j] -= vexp;
|
|
@@ -725,7 +701,7 @@ void train_NonUniform(RangeStat rs, float rs_arg,
|
|
|
725
701
|
// transpose
|
|
726
702
|
std::vector<float> xt(n * d);
|
|
727
703
|
for (size_t i = 1; i < n; i++) {
|
|
728
|
-
const float
|
|
704
|
+
const float* xi = x + i * d;
|
|
729
705
|
for (size_t j = 0; j < d; j++) {
|
|
730
706
|
xt[j * n + i] = xi[j];
|
|
731
707
|
}
|
|
@@ -733,108 +709,98 @@ void train_NonUniform(RangeStat rs, float rs_arg,
|
|
|
733
709
|
std::vector<float> trained_d(2);
|
|
734
710
|
#pragma omp parallel for
|
|
735
711
|
for (int j = 0; j < d; j++) {
|
|
736
|
-
train_Uniform(rs, rs_arg,
|
|
737
|
-
n, k, xt.data() + j * n,
|
|
738
|
-
trained_d);
|
|
712
|
+
train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d);
|
|
739
713
|
vmin[j] = trained_d[0];
|
|
740
714
|
vmax[j] = trained_d[1];
|
|
741
715
|
}
|
|
742
716
|
}
|
|
743
717
|
}
|
|
744
718
|
|
|
745
|
-
|
|
746
|
-
|
|
747
719
|
/*******************************************************************
|
|
748
720
|
* Similarity: gets vector components and computes a similarity wrt. a
|
|
749
721
|
* query vector stored in the object. The data fields just encapsulate
|
|
750
722
|
* an accumulator.
|
|
751
723
|
*/
|
|
752
724
|
|
|
753
|
-
template<int SIMDWIDTH>
|
|
725
|
+
template <int SIMDWIDTH>
|
|
754
726
|
struct SimilarityL2 {};
|
|
755
727
|
|
|
756
|
-
|
|
757
|
-
template<>
|
|
728
|
+
template <>
|
|
758
729
|
struct SimilarityL2<1> {
|
|
759
730
|
static constexpr int simdwidth = 1;
|
|
760
731
|
static constexpr MetricType metric_type = METRIC_L2;
|
|
761
732
|
|
|
762
733
|
const float *y, *yi;
|
|
763
734
|
|
|
764
|
-
explicit SimilarityL2
|
|
735
|
+
explicit SimilarityL2(const float* y) : y(y) {}
|
|
765
736
|
|
|
766
737
|
/******* scalar accumulator *******/
|
|
767
738
|
|
|
768
739
|
float accu;
|
|
769
740
|
|
|
770
|
-
void begin
|
|
741
|
+
void begin() {
|
|
771
742
|
accu = 0;
|
|
772
743
|
yi = y;
|
|
773
744
|
}
|
|
774
745
|
|
|
775
|
-
void add_component
|
|
746
|
+
void add_component(float x) {
|
|
776
747
|
float tmp = *yi++ - x;
|
|
777
748
|
accu += tmp * tmp;
|
|
778
749
|
}
|
|
779
750
|
|
|
780
|
-
void add_component_2
|
|
751
|
+
void add_component_2(float x1, float x2) {
|
|
781
752
|
float tmp = x1 - x2;
|
|
782
753
|
accu += tmp * tmp;
|
|
783
754
|
}
|
|
784
755
|
|
|
785
|
-
float result
|
|
756
|
+
float result() {
|
|
786
757
|
return accu;
|
|
787
758
|
}
|
|
788
759
|
};
|
|
789
760
|
|
|
790
|
-
|
|
791
761
|
#ifdef __AVX2__
|
|
792
|
-
template<>
|
|
762
|
+
template <>
|
|
793
763
|
struct SimilarityL2<8> {
|
|
794
764
|
static constexpr int simdwidth = 8;
|
|
795
765
|
static constexpr MetricType metric_type = METRIC_L2;
|
|
796
766
|
|
|
797
767
|
const float *y, *yi;
|
|
798
768
|
|
|
799
|
-
explicit SimilarityL2
|
|
769
|
+
explicit SimilarityL2(const float* y) : y(y) {}
|
|
800
770
|
__m256 accu8;
|
|
801
771
|
|
|
802
|
-
void begin_8
|
|
772
|
+
void begin_8() {
|
|
803
773
|
accu8 = _mm256_setzero_ps();
|
|
804
774
|
yi = y;
|
|
805
775
|
}
|
|
806
776
|
|
|
807
|
-
void add_8_components
|
|
808
|
-
__m256 yiv = _mm256_loadu_ps
|
|
777
|
+
void add_8_components(__m256 x) {
|
|
778
|
+
__m256 yiv = _mm256_loadu_ps(yi);
|
|
809
779
|
yi += 8;
|
|
810
|
-
__m256 tmp = yiv
|
|
811
|
-
accu8
|
|
780
|
+
__m256 tmp = _mm256_sub_ps(yiv, x);
|
|
781
|
+
accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
|
|
812
782
|
}
|
|
813
783
|
|
|
814
|
-
void add_8_components_2
|
|
815
|
-
__m256 tmp = y
|
|
816
|
-
accu8
|
|
784
|
+
void add_8_components_2(__m256 x, __m256 y) {
|
|
785
|
+
__m256 tmp = _mm256_sub_ps(y, x);
|
|
786
|
+
accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
|
|
817
787
|
}
|
|
818
788
|
|
|
819
|
-
float result_8
|
|
789
|
+
float result_8() {
|
|
820
790
|
__m256 sum = _mm256_hadd_ps(accu8, accu8);
|
|
821
791
|
__m256 sum2 = _mm256_hadd_ps(sum, sum);
|
|
822
792
|
// now add the 0th and 4th component
|
|
823
|
-
return
|
|
824
|
-
|
|
825
|
-
_mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
|
|
793
|
+
return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
|
|
794
|
+
_mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
|
|
826
795
|
}
|
|
827
|
-
|
|
828
796
|
};
|
|
829
797
|
|
|
830
798
|
#endif
|
|
831
799
|
|
|
832
|
-
|
|
833
|
-
template<int SIMDWIDTH>
|
|
800
|
+
template <int SIMDWIDTH>
|
|
834
801
|
struct SimilarityIP {};
|
|
835
802
|
|
|
836
|
-
|
|
837
|
-
template<>
|
|
803
|
+
template <>
|
|
838
804
|
struct SimilarityIP<1> {
|
|
839
805
|
static constexpr int simdwidth = 1;
|
|
840
806
|
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
|
@@ -842,30 +808,29 @@ struct SimilarityIP<1> {
|
|
|
842
808
|
|
|
843
809
|
float accu;
|
|
844
810
|
|
|
845
|
-
explicit SimilarityIP
|
|
846
|
-
y (y) {}
|
|
811
|
+
explicit SimilarityIP(const float* y) : y(y) {}
|
|
847
812
|
|
|
848
|
-
void begin
|
|
813
|
+
void begin() {
|
|
849
814
|
accu = 0;
|
|
850
815
|
yi = y;
|
|
851
816
|
}
|
|
852
817
|
|
|
853
|
-
void add_component
|
|
854
|
-
accu +=
|
|
818
|
+
void add_component(float x) {
|
|
819
|
+
accu += *yi++ * x;
|
|
855
820
|
}
|
|
856
821
|
|
|
857
|
-
void add_component_2
|
|
858
|
-
accu +=
|
|
822
|
+
void add_component_2(float x1, float x2) {
|
|
823
|
+
accu += x1 * x2;
|
|
859
824
|
}
|
|
860
825
|
|
|
861
|
-
float result
|
|
826
|
+
float result() {
|
|
862
827
|
return accu;
|
|
863
828
|
}
|
|
864
829
|
};
|
|
865
830
|
|
|
866
831
|
#ifdef __AVX2__
|
|
867
832
|
|
|
868
|
-
template<>
|
|
833
|
+
template <>
|
|
869
834
|
struct SimilarityIP<8> {
|
|
870
835
|
static constexpr int simdwidth = 8;
|
|
871
836
|
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
|
@@ -874,59 +839,53 @@ struct SimilarityIP<8> {
|
|
|
874
839
|
|
|
875
840
|
float accu;
|
|
876
841
|
|
|
877
|
-
explicit SimilarityIP
|
|
878
|
-
y (y) {}
|
|
842
|
+
explicit SimilarityIP(const float* y) : y(y) {}
|
|
879
843
|
|
|
880
844
|
__m256 accu8;
|
|
881
845
|
|
|
882
|
-
void begin_8
|
|
846
|
+
void begin_8() {
|
|
883
847
|
accu8 = _mm256_setzero_ps();
|
|
884
848
|
yi = y;
|
|
885
849
|
}
|
|
886
850
|
|
|
887
|
-
void add_8_components
|
|
888
|
-
__m256 yiv = _mm256_loadu_ps
|
|
851
|
+
void add_8_components(__m256 x) {
|
|
852
|
+
__m256 yiv = _mm256_loadu_ps(yi);
|
|
889
853
|
yi += 8;
|
|
890
|
-
accu8
|
|
854
|
+
accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
|
|
891
855
|
}
|
|
892
856
|
|
|
893
|
-
void add_8_components_2
|
|
894
|
-
accu8
|
|
857
|
+
void add_8_components_2(__m256 x1, __m256 x2) {
|
|
858
|
+
accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
|
|
895
859
|
}
|
|
896
860
|
|
|
897
|
-
float result_8
|
|
861
|
+
float result_8() {
|
|
898
862
|
__m256 sum = _mm256_hadd_ps(accu8, accu8);
|
|
899
863
|
__m256 sum2 = _mm256_hadd_ps(sum, sum);
|
|
900
864
|
// now add the 0th and 4th component
|
|
901
|
-
return
|
|
902
|
-
|
|
903
|
-
_mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
|
|
865
|
+
return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
|
|
866
|
+
_mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
|
|
904
867
|
}
|
|
905
868
|
};
|
|
906
869
|
#endif
|
|
907
870
|
|
|
908
|
-
|
|
909
871
|
/*******************************************************************
|
|
910
872
|
* DistanceComputer: combines a similarity and a quantizer to do
|
|
911
873
|
* code-to-vector or code-to-code comparisons
|
|
912
874
|
*******************************************************************/
|
|
913
875
|
|
|
914
|
-
template<class Quantizer, class Similarity, int SIMDWIDTH>
|
|
876
|
+
template <class Quantizer, class Similarity, int SIMDWIDTH>
|
|
915
877
|
struct DCTemplate : SQDistanceComputer {};
|
|
916
878
|
|
|
917
|
-
template<class Quantizer, class Similarity>
|
|
918
|
-
struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
|
|
919
|
-
{
|
|
879
|
+
template <class Quantizer, class Similarity>
|
|
880
|
+
struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
|
|
920
881
|
using Sim = Similarity;
|
|
921
882
|
|
|
922
883
|
Quantizer quant;
|
|
923
884
|
|
|
924
|
-
DCTemplate(size_t d, const std::vector<float
|
|
925
|
-
|
|
926
|
-
{}
|
|
885
|
+
DCTemplate(size_t d, const std::vector<float>& trained)
|
|
886
|
+
: quant(d, trained) {}
|
|
927
887
|
|
|
928
888
|
float compute_distance(const float* x, const uint8_t* code) const {
|
|
929
|
-
|
|
930
889
|
Similarity sim(x);
|
|
931
890
|
sim.begin();
|
|
932
891
|
for (size_t i = 0; i < quant.d; i++) {
|
|
@@ -937,7 +896,7 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
|
|
|
937
896
|
}
|
|
938
897
|
|
|
939
898
|
float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
940
|
-
|
|
899
|
+
const {
|
|
941
900
|
Similarity sim(nullptr);
|
|
942
901
|
sim.begin();
|
|
943
902
|
for (size_t i = 0; i < quant.d; i++) {
|
|
@@ -948,41 +907,37 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
|
|
|
948
907
|
return sim.result();
|
|
949
908
|
}
|
|
950
909
|
|
|
951
|
-
void set_query
|
|
910
|
+
void set_query(const float* x) final {
|
|
952
911
|
q = x;
|
|
953
912
|
}
|
|
954
913
|
|
|
955
914
|
/// compute distance of vector i to current query
|
|
956
|
-
float operator
|
|
957
|
-
return
|
|
915
|
+
float operator()(idx_t i) final {
|
|
916
|
+
return query_to_code(codes + i * code_size);
|
|
958
917
|
}
|
|
959
918
|
|
|
960
|
-
float symmetric_dis
|
|
961
|
-
return compute_code_distance
|
|
962
|
-
|
|
919
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
920
|
+
return compute_code_distance(
|
|
921
|
+
codes + i * code_size, codes + j * code_size);
|
|
963
922
|
}
|
|
964
923
|
|
|
965
|
-
float query_to_code
|
|
966
|
-
return compute_distance
|
|
924
|
+
float query_to_code(const uint8_t* code) const final {
|
|
925
|
+
return compute_distance(q, code);
|
|
967
926
|
}
|
|
968
|
-
|
|
969
927
|
};
|
|
970
928
|
|
|
971
929
|
#ifdef USE_F16C
|
|
972
930
|
|
|
973
|
-
template<class Quantizer, class Similarity>
|
|
974
|
-
struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
|
|
975
|
-
{
|
|
931
|
+
template <class Quantizer, class Similarity>
|
|
932
|
+
struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
|
|
976
933
|
using Sim = Similarity;
|
|
977
934
|
|
|
978
935
|
Quantizer quant;
|
|
979
936
|
|
|
980
|
-
DCTemplate(size_t d, const std::vector<float
|
|
981
|
-
|
|
982
|
-
{}
|
|
937
|
+
DCTemplate(size_t d, const std::vector<float>& trained)
|
|
938
|
+
: quant(d, trained) {}
|
|
983
939
|
|
|
984
940
|
float compute_distance(const float* x, const uint8_t* code) const {
|
|
985
|
-
|
|
986
941
|
Similarity sim(x);
|
|
987
942
|
sim.begin_8();
|
|
988
943
|
for (size_t i = 0; i < quant.d; i += 8) {
|
|
@@ -993,7 +948,7 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
|
|
|
993
948
|
}
|
|
994
949
|
|
|
995
950
|
float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
996
|
-
|
|
951
|
+
const {
|
|
997
952
|
Similarity sim(nullptr);
|
|
998
953
|
sim.begin_8();
|
|
999
954
|
for (size_t i = 0; i < quant.d; i += 8) {
|
|
@@ -1004,49 +959,45 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
|
|
|
1004
959
|
return sim.result_8();
|
|
1005
960
|
}
|
|
1006
961
|
|
|
1007
|
-
void set_query
|
|
962
|
+
void set_query(const float* x) final {
|
|
1008
963
|
q = x;
|
|
1009
964
|
}
|
|
1010
965
|
|
|
1011
966
|
/// compute distance of vector i to current query
|
|
1012
|
-
float operator
|
|
1013
|
-
return
|
|
967
|
+
float operator()(idx_t i) final {
|
|
968
|
+
return query_to_code(codes + i * code_size);
|
|
1014
969
|
}
|
|
1015
970
|
|
|
1016
|
-
float symmetric_dis
|
|
1017
|
-
return compute_code_distance
|
|
1018
|
-
|
|
971
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
972
|
+
return compute_code_distance(
|
|
973
|
+
codes + i * code_size, codes + j * code_size);
|
|
1019
974
|
}
|
|
1020
975
|
|
|
1021
|
-
float query_to_code
|
|
1022
|
-
return compute_distance
|
|
976
|
+
float query_to_code(const uint8_t* code) const final {
|
|
977
|
+
return compute_distance(q, code);
|
|
1023
978
|
}
|
|
1024
|
-
|
|
1025
979
|
};
|
|
1026
980
|
|
|
1027
981
|
#endif
|
|
1028
982
|
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
983
|
/*******************************************************************
|
|
1032
984
|
* DistanceComputerByte: computes distances in the integer domain
|
|
1033
985
|
*******************************************************************/
|
|
1034
986
|
|
|
1035
|
-
template<class Similarity, int SIMDWIDTH>
|
|
987
|
+
template <class Similarity, int SIMDWIDTH>
|
|
1036
988
|
struct DistanceComputerByte : SQDistanceComputer {};
|
|
1037
989
|
|
|
1038
|
-
template<class Similarity>
|
|
990
|
+
template <class Similarity>
|
|
1039
991
|
struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
|
|
1040
992
|
using Sim = Similarity;
|
|
1041
993
|
|
|
1042
994
|
int d;
|
|
1043
995
|
std::vector<uint8_t> tmp;
|
|
1044
996
|
|
|
1045
|
-
DistanceComputerByte(int d, const std::vector<float
|
|
1046
|
-
}
|
|
997
|
+
DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
|
|
1047
998
|
|
|
1048
999
|
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
1049
|
-
|
|
1000
|
+
const {
|
|
1050
1001
|
int accu = 0;
|
|
1051
1002
|
for (int i = 0; i < d; i++) {
|
|
1052
1003
|
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
@@ -1059,7 +1010,7 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
|
|
|
1059
1010
|
return accu;
|
|
1060
1011
|
}
|
|
1061
1012
|
|
|
1062
|
-
void set_query
|
|
1013
|
+
void set_query(const float* x) final {
|
|
1063
1014
|
for (int i = 0; i < d; i++) {
|
|
1064
1015
|
tmp[i] = int(x[i]);
|
|
1065
1016
|
}
|
|
@@ -1071,44 +1022,41 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
|
|
|
1071
1022
|
}
|
|
1072
1023
|
|
|
1073
1024
|
/// compute distance of vector i to current query
|
|
1074
|
-
float operator
|
|
1075
|
-
return
|
|
1025
|
+
float operator()(idx_t i) final {
|
|
1026
|
+
return query_to_code(codes + i * code_size);
|
|
1076
1027
|
}
|
|
1077
1028
|
|
|
1078
|
-
float symmetric_dis
|
|
1079
|
-
return compute_code_distance
|
|
1080
|
-
|
|
1029
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
1030
|
+
return compute_code_distance(
|
|
1031
|
+
codes + i * code_size, codes + j * code_size);
|
|
1081
1032
|
}
|
|
1082
1033
|
|
|
1083
|
-
float query_to_code
|
|
1084
|
-
return compute_code_distance
|
|
1034
|
+
float query_to_code(const uint8_t* code) const final {
|
|
1035
|
+
return compute_code_distance(tmp.data(), code);
|
|
1085
1036
|
}
|
|
1086
|
-
|
|
1087
1037
|
};
|
|
1088
1038
|
|
|
1089
1039
|
#ifdef __AVX2__
|
|
1090
1040
|
|
|
1091
|
-
|
|
1092
|
-
template<class Similarity>
|
|
1041
|
+
template <class Similarity>
|
|
1093
1042
|
struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|
1094
1043
|
using Sim = Similarity;
|
|
1095
1044
|
|
|
1096
1045
|
int d;
|
|
1097
1046
|
std::vector<uint8_t> tmp;
|
|
1098
1047
|
|
|
1099
|
-
DistanceComputerByte(int d, const std::vector<float
|
|
1100
|
-
}
|
|
1048
|
+
DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
|
|
1101
1049
|
|
|
1102
1050
|
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
1103
|
-
|
|
1051
|
+
const {
|
|
1104
1052
|
// __m256i accu = _mm256_setzero_ps ();
|
|
1105
|
-
__m256i accu = _mm256_setzero_si256
|
|
1053
|
+
__m256i accu = _mm256_setzero_si256();
|
|
1106
1054
|
for (int i = 0; i < d; i += 16) {
|
|
1107
1055
|
// load 16 bytes, convert to 16 uint16_t
|
|
1108
|
-
__m256i c1 = _mm256_cvtepu8_epi16
|
|
1109
|
-
|
|
1110
|
-
__m256i c2 = _mm256_cvtepu8_epi16
|
|
1111
|
-
|
|
1056
|
+
__m256i c1 = _mm256_cvtepu8_epi16(
|
|
1057
|
+
_mm_loadu_si128((__m128i*)(code1 + i)));
|
|
1058
|
+
__m256i c2 = _mm256_cvtepu8_epi16(
|
|
1059
|
+
_mm_loadu_si128((__m128i*)(code2 + i)));
|
|
1112
1060
|
__m256i prod32;
|
|
1113
1061
|
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
1114
1062
|
prod32 = _mm256_madd_epi16(c1, c2);
|
|
@@ -1116,17 +1064,16 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|
|
1116
1064
|
__m256i diff = _mm256_sub_epi16(c1, c2);
|
|
1117
1065
|
prod32 = _mm256_madd_epi16(diff, diff);
|
|
1118
1066
|
}
|
|
1119
|
-
accu = _mm256_add_epi32
|
|
1120
|
-
|
|
1067
|
+
accu = _mm256_add_epi32(accu, prod32);
|
|
1121
1068
|
}
|
|
1122
1069
|
__m128i sum = _mm256_extractf128_si256(accu, 0);
|
|
1123
|
-
sum = _mm_add_epi32
|
|
1124
|
-
sum = _mm_hadd_epi32
|
|
1125
|
-
sum = _mm_hadd_epi32
|
|
1126
|
-
return _mm_cvtsi128_si32
|
|
1070
|
+
sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1));
|
|
1071
|
+
sum = _mm_hadd_epi32(sum, sum);
|
|
1072
|
+
sum = _mm_hadd_epi32(sum, sum);
|
|
1073
|
+
return _mm_cvtsi128_si32(sum);
|
|
1127
1074
|
}
|
|
1128
1075
|
|
|
1129
|
-
void set_query
|
|
1076
|
+
void set_query(const float* x) final {
|
|
1130
1077
|
/*
|
|
1131
1078
|
for (int i = 0; i < d; i += 8) {
|
|
1132
1079
|
__m256 xi = _mm256_loadu_ps (x + i);
|
|
@@ -1143,20 +1090,18 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|
|
1143
1090
|
}
|
|
1144
1091
|
|
|
1145
1092
|
/// compute distance of vector i to current query
|
|
1146
|
-
float operator
|
|
1147
|
-
return
|
|
1093
|
+
float operator()(idx_t i) final {
|
|
1094
|
+
return query_to_code(codes + i * code_size);
|
|
1148
1095
|
}
|
|
1149
1096
|
|
|
1150
|
-
float symmetric_dis
|
|
1151
|
-
return compute_code_distance
|
|
1152
|
-
|
|
1097
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
1098
|
+
return compute_code_distance(
|
|
1099
|
+
codes + i * code_size, codes + j * code_size);
|
|
1153
1100
|
}
|
|
1154
1101
|
|
|
1155
|
-
float query_to_code
|
|
1156
|
-
return compute_code_distance
|
|
1102
|
+
float query_to_code(const uint8_t* code) const final {
|
|
1103
|
+
return compute_code_distance(tmp.data(), code);
|
|
1157
1104
|
}
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
1105
|
};
|
|
1161
1106
|
|
|
1162
1107
|
#endif
|
|
@@ -1166,215 +1111,218 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|
|
1166
1111
|
* specialization
|
|
1167
1112
|
*******************************************************************/
|
|
1168
1113
|
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
{
|
|
1114
|
+
template <class Sim>
|
|
1115
|
+
SQDistanceComputer* select_distance_computer(
|
|
1116
|
+
QuantizerType qtype,
|
|
1117
|
+
size_t d,
|
|
1118
|
+
const std::vector<float>& trained) {
|
|
1175
1119
|
constexpr int SIMDWIDTH = Sim::simdwidth;
|
|
1176
|
-
switch(qtype) {
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1120
|
+
switch (qtype) {
|
|
1121
|
+
case ScalarQuantizer::QT_8bit_uniform:
|
|
1122
|
+
return new DCTemplate<
|
|
1123
|
+
QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
|
|
1124
|
+
Sim,
|
|
1125
|
+
SIMDWIDTH>(d, trained);
|
|
1126
|
+
|
|
1127
|
+
case ScalarQuantizer::QT_4bit_uniform:
|
|
1128
|
+
return new DCTemplate<
|
|
1129
|
+
QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
|
|
1130
|
+
Sim,
|
|
1131
|
+
SIMDWIDTH>(d, trained);
|
|
1132
|
+
|
|
1133
|
+
case ScalarQuantizer::QT_8bit:
|
|
1134
|
+
return new DCTemplate<
|
|
1135
|
+
QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
|
|
1136
|
+
Sim,
|
|
1137
|
+
SIMDWIDTH>(d, trained);
|
|
1138
|
+
|
|
1139
|
+
case ScalarQuantizer::QT_6bit:
|
|
1140
|
+
return new DCTemplate<
|
|
1141
|
+
QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
|
|
1142
|
+
Sim,
|
|
1143
|
+
SIMDWIDTH>(d, trained);
|
|
1144
|
+
|
|
1145
|
+
case ScalarQuantizer::QT_4bit:
|
|
1146
|
+
return new DCTemplate<
|
|
1147
|
+
QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
|
|
1148
|
+
Sim,
|
|
1149
|
+
SIMDWIDTH>(d, trained);
|
|
1150
|
+
|
|
1151
|
+
case ScalarQuantizer::QT_fp16:
|
|
1152
|
+
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
|
|
1153
|
+
d, trained);
|
|
1154
|
+
|
|
1155
|
+
case ScalarQuantizer::QT_8bit_direct:
|
|
1156
|
+
if (d % 16 == 0) {
|
|
1157
|
+
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
|
|
1158
|
+
} else {
|
|
1159
|
+
return new DCTemplate<
|
|
1160
|
+
Quantizer8bitDirect<SIMDWIDTH>,
|
|
1161
|
+
Sim,
|
|
1162
|
+
SIMDWIDTH>(d, trained);
|
|
1163
|
+
}
|
|
1208
1164
|
}
|
|
1209
|
-
FAISS_THROW_MSG
|
|
1165
|
+
FAISS_THROW_MSG("unknown qtype");
|
|
1210
1166
|
return nullptr;
|
|
1211
1167
|
}
|
|
1212
1168
|
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
1169
|
} // anonymous namespace
|
|
1216
1170
|
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
1171
|
/*******************************************************************
|
|
1220
1172
|
* ScalarQuantizer implementation
|
|
1221
1173
|
********************************************************************/
|
|
1222
1174
|
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
(size_t d, QuantizerType qtype):
|
|
1227
|
-
qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d(d)
|
|
1228
|
-
{
|
|
1229
|
-
set_derived_sizes();
|
|
1175
|
+
ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
|
|
1176
|
+
: qtype(qtype), rangestat(RS_minmax), rangestat_arg(0), d(d) {
|
|
1177
|
+
set_derived_sizes();
|
|
1230
1178
|
}
|
|
1231
1179
|
|
|
1232
|
-
ScalarQuantizer::ScalarQuantizer
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1180
|
+
ScalarQuantizer::ScalarQuantizer()
|
|
1181
|
+
: qtype(QT_8bit),
|
|
1182
|
+
rangestat(RS_minmax),
|
|
1183
|
+
rangestat_arg(0),
|
|
1184
|
+
d(0),
|
|
1185
|
+
bits(0),
|
|
1186
|
+
code_size(0) {}
|
|
1236
1187
|
|
|
1237
|
-
void ScalarQuantizer::set_derived_sizes
|
|
1238
|
-
{
|
|
1188
|
+
void ScalarQuantizer::set_derived_sizes() {
|
|
1239
1189
|
switch (qtype) {
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1190
|
+
case QT_8bit:
|
|
1191
|
+
case QT_8bit_uniform:
|
|
1192
|
+
case QT_8bit_direct:
|
|
1193
|
+
code_size = d;
|
|
1194
|
+
bits = 8;
|
|
1195
|
+
break;
|
|
1196
|
+
case QT_4bit:
|
|
1197
|
+
case QT_4bit_uniform:
|
|
1198
|
+
code_size = (d + 1) / 2;
|
|
1199
|
+
bits = 4;
|
|
1200
|
+
break;
|
|
1201
|
+
case QT_6bit:
|
|
1202
|
+
code_size = (d * 6 + 7) / 8;
|
|
1203
|
+
bits = 6;
|
|
1204
|
+
break;
|
|
1205
|
+
case QT_fp16:
|
|
1206
|
+
code_size = d * 2;
|
|
1207
|
+
bits = 16;
|
|
1208
|
+
break;
|
|
1259
1209
|
}
|
|
1260
1210
|
}
|
|
1261
1211
|
|
|
1262
|
-
void ScalarQuantizer::train
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
qtype == QT_8bit ? 8 : -1;
|
|
1212
|
+
void ScalarQuantizer::train(size_t n, const float* x) {
|
|
1213
|
+
int bit_per_dim = qtype == QT_4bit_uniform ? 4
|
|
1214
|
+
: qtype == QT_4bit ? 4
|
|
1215
|
+
: qtype == QT_6bit ? 6
|
|
1216
|
+
: qtype == QT_8bit_uniform ? 8
|
|
1217
|
+
: qtype == QT_8bit ? 8
|
|
1218
|
+
: -1;
|
|
1270
1219
|
|
|
1271
1220
|
switch (qtype) {
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1221
|
+
case QT_4bit_uniform:
|
|
1222
|
+
case QT_8bit_uniform:
|
|
1223
|
+
train_Uniform(
|
|
1224
|
+
rangestat,
|
|
1225
|
+
rangestat_arg,
|
|
1226
|
+
n * d,
|
|
1227
|
+
1 << bit_per_dim,
|
|
1228
|
+
x,
|
|
1229
|
+
trained);
|
|
1230
|
+
break;
|
|
1231
|
+
case QT_4bit:
|
|
1232
|
+
case QT_8bit:
|
|
1233
|
+
case QT_6bit:
|
|
1234
|
+
train_NonUniform(
|
|
1235
|
+
rangestat,
|
|
1236
|
+
rangestat_arg,
|
|
1237
|
+
n,
|
|
1238
|
+
d,
|
|
1239
|
+
1 << bit_per_dim,
|
|
1240
|
+
x,
|
|
1241
|
+
trained);
|
|
1242
|
+
break;
|
|
1243
|
+
case QT_fp16:
|
|
1244
|
+
case QT_8bit_direct:
|
|
1245
|
+
// no training necessary
|
|
1246
|
+
break;
|
|
1284
1247
|
}
|
|
1285
1248
|
}
|
|
1286
1249
|
|
|
1287
|
-
void ScalarQuantizer::train_residual(
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
{
|
|
1293
|
-
const float
|
|
1250
|
+
void ScalarQuantizer::train_residual(
|
|
1251
|
+
size_t n,
|
|
1252
|
+
const float* x,
|
|
1253
|
+
Index* quantizer,
|
|
1254
|
+
bool by_residual,
|
|
1255
|
+
bool verbose) {
|
|
1256
|
+
const float* x_in = x;
|
|
1294
1257
|
|
|
1295
1258
|
// 100k points more than enough
|
|
1296
|
-
x = fvecs_maybe_subsample (
|
|
1297
|
-
d, (size_t*)&n, 100000,
|
|
1298
|
-
x, verbose, 1234);
|
|
1259
|
+
x = fvecs_maybe_subsample(d, (size_t*)&n, 100000, x, verbose, 1234);
|
|
1299
1260
|
|
|
1300
|
-
ScopeDeleter<float> del_x
|
|
1261
|
+
ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
|
|
1301
1262
|
|
|
1302
1263
|
if (by_residual) {
|
|
1303
1264
|
std::vector<Index::idx_t> idx(n);
|
|
1304
|
-
quantizer->assign
|
|
1265
|
+
quantizer->assign(n, x, idx.data());
|
|
1305
1266
|
|
|
1306
1267
|
std::vector<float> residuals(n * d);
|
|
1307
|
-
quantizer->compute_residual_n
|
|
1268
|
+
quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
|
|
1308
1269
|
|
|
1309
|
-
train
|
|
1270
|
+
train(n, residuals.data());
|
|
1310
1271
|
} else {
|
|
1311
|
-
train
|
|
1272
|
+
train(n, x);
|
|
1312
1273
|
}
|
|
1313
1274
|
}
|
|
1314
1275
|
|
|
1315
|
-
|
|
1316
|
-
ScalarQuantizer::Quantizer *ScalarQuantizer::select_quantizer () const
|
|
1317
|
-
{
|
|
1276
|
+
ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
|
|
1318
1277
|
#ifdef USE_F16C
|
|
1319
1278
|
if (d % 8 == 0) {
|
|
1320
|
-
return select_quantizer_1<8>
|
|
1279
|
+
return select_quantizer_1<8>(qtype, d, trained);
|
|
1321
1280
|
} else
|
|
1322
1281
|
#endif
|
|
1323
1282
|
{
|
|
1324
|
-
return select_quantizer_1<1>
|
|
1283
|
+
return select_quantizer_1<1>(qtype, d, trained);
|
|
1325
1284
|
}
|
|
1326
1285
|
}
|
|
1327
1286
|
|
|
1287
|
+
void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
1288
|
+
const {
|
|
1289
|
+
std::unique_ptr<Quantizer> squant(select_quantizer());
|
|
1328
1290
|
|
|
1329
|
-
|
|
1330
|
-
uint8_t * codes,
|
|
1331
|
-
size_t n) const
|
|
1332
|
-
{
|
|
1333
|
-
std::unique_ptr<Quantizer> squant(select_quantizer ());
|
|
1334
|
-
|
|
1335
|
-
memset (codes, 0, code_size * n);
|
|
1291
|
+
memset(codes, 0, code_size * n);
|
|
1336
1292
|
#pragma omp parallel for
|
|
1337
1293
|
for (int64_t i = 0; i < n; i++)
|
|
1338
|
-
squant->encode_vector
|
|
1294
|
+
squant->encode_vector(x + i * d, codes + i * code_size);
|
|
1339
1295
|
}
|
|
1340
1296
|
|
|
1341
|
-
void ScalarQuantizer::decode
|
|
1342
|
-
|
|
1343
|
-
std::unique_ptr<Quantizer> squant(select_quantizer ());
|
|
1297
|
+
void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
|
|
1298
|
+
std::unique_ptr<Quantizer> squant(select_quantizer());
|
|
1344
1299
|
|
|
1345
1300
|
#pragma omp parallel for
|
|
1346
1301
|
for (int64_t i = 0; i < n; i++)
|
|
1347
|
-
squant->decode_vector
|
|
1302
|
+
squant->decode_vector(codes + i * code_size, x + i * d);
|
|
1348
1303
|
}
|
|
1349
1304
|
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
ScalarQuantizer::get_distance_computer (MetricType metric) const
|
|
1353
|
-
{
|
|
1305
|
+
SQDistanceComputer* ScalarQuantizer::get_distance_computer(
|
|
1306
|
+
MetricType metric) const {
|
|
1354
1307
|
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
|
|
1355
1308
|
#ifdef USE_F16C
|
|
1356
1309
|
if (d % 8 == 0) {
|
|
1357
1310
|
if (metric == METRIC_L2) {
|
|
1358
|
-
return select_distance_computer<SimilarityL2<8
|
|
1359
|
-
(qtype, d, trained);
|
|
1311
|
+
return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
|
|
1360
1312
|
} else {
|
|
1361
|
-
return select_distance_computer<SimilarityIP<8
|
|
1362
|
-
(qtype, d, trained);
|
|
1313
|
+
return select_distance_computer<SimilarityIP<8>>(qtype, d, trained);
|
|
1363
1314
|
}
|
|
1364
1315
|
} else
|
|
1365
1316
|
#endif
|
|
1366
1317
|
{
|
|
1367
1318
|
if (metric == METRIC_L2) {
|
|
1368
|
-
return select_distance_computer<SimilarityL2<1
|
|
1369
|
-
(qtype, d, trained);
|
|
1319
|
+
return select_distance_computer<SimilarityL2<1>>(qtype, d, trained);
|
|
1370
1320
|
} else {
|
|
1371
|
-
return select_distance_computer<SimilarityIP<1
|
|
1372
|
-
(qtype, d, trained);
|
|
1321
|
+
return select_distance_computer<SimilarityIP<1>>(qtype, d, trained);
|
|
1373
1322
|
}
|
|
1374
1323
|
}
|
|
1375
1324
|
}
|
|
1376
1325
|
|
|
1377
|
-
|
|
1378
1326
|
/*******************************************************************
|
|
1379
1327
|
* IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
|
|
1380
1328
|
*
|
|
@@ -1384,54 +1332,57 @@ ScalarQuantizer::get_distance_computer (MetricType metric) const
|
|
|
1384
1332
|
|
|
1385
1333
|
namespace {
|
|
1386
1334
|
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
struct IVFSQScannerIP: InvertedListScanner {
|
|
1335
|
+
template <class DCClass>
|
|
1336
|
+
struct IVFSQScannerIP : InvertedListScanner {
|
|
1390
1337
|
DCClass dc;
|
|
1391
1338
|
bool store_pairs, by_residual;
|
|
1392
1339
|
|
|
1393
1340
|
size_t code_size;
|
|
1394
1341
|
|
|
1395
|
-
idx_t list_no;
|
|
1396
|
-
float accu0;
|
|
1397
|
-
|
|
1398
|
-
IVFSQScannerIP(int d, const std::vector<float> & trained,
|
|
1399
|
-
size_t code_size, bool store_pairs,
|
|
1400
|
-
bool by_residual):
|
|
1401
|
-
dc(d, trained), store_pairs(store_pairs),
|
|
1402
|
-
by_residual(by_residual),
|
|
1403
|
-
code_size(code_size), list_no(0), accu0(0)
|
|
1404
|
-
{}
|
|
1342
|
+
idx_t list_no; /// current list (set to 0 for Flat index
|
|
1343
|
+
float accu0; /// added to all distances
|
|
1405
1344
|
|
|
1345
|
+
IVFSQScannerIP(
|
|
1346
|
+
int d,
|
|
1347
|
+
const std::vector<float>& trained,
|
|
1348
|
+
size_t code_size,
|
|
1349
|
+
bool store_pairs,
|
|
1350
|
+
bool by_residual)
|
|
1351
|
+
: dc(d, trained),
|
|
1352
|
+
store_pairs(store_pairs),
|
|
1353
|
+
by_residual(by_residual),
|
|
1354
|
+
code_size(code_size),
|
|
1355
|
+
list_no(0),
|
|
1356
|
+
accu0(0) {}
|
|
1406
1357
|
|
|
1407
|
-
void set_query
|
|
1408
|
-
dc.set_query
|
|
1358
|
+
void set_query(const float* query) override {
|
|
1359
|
+
dc.set_query(query);
|
|
1409
1360
|
}
|
|
1410
1361
|
|
|
1411
|
-
void set_list
|
|
1362
|
+
void set_list(idx_t list_no, float coarse_dis) override {
|
|
1412
1363
|
this->list_no = list_no;
|
|
1413
1364
|
accu0 = by_residual ? coarse_dis : 0;
|
|
1414
1365
|
}
|
|
1415
1366
|
|
|
1416
|
-
float distance_to_code
|
|
1417
|
-
return accu0 + dc.query_to_code
|
|
1367
|
+
float distance_to_code(const uint8_t* code) const final {
|
|
1368
|
+
return accu0 + dc.query_to_code(code);
|
|
1418
1369
|
}
|
|
1419
1370
|
|
|
1420
|
-
size_t scan_codes
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1371
|
+
size_t scan_codes(
|
|
1372
|
+
size_t list_size,
|
|
1373
|
+
const uint8_t* codes,
|
|
1374
|
+
const idx_t* ids,
|
|
1375
|
+
float* simi,
|
|
1376
|
+
idx_t* idxi,
|
|
1377
|
+
size_t k) const override {
|
|
1426
1378
|
size_t nup = 0;
|
|
1427
1379
|
|
|
1428
1380
|
for (size_t j = 0; j < list_size; j++) {
|
|
1381
|
+
float accu = accu0 + dc.query_to_code(codes);
|
|
1429
1382
|
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
if (accu > simi [0]) {
|
|
1383
|
+
if (accu > simi[0]) {
|
|
1433
1384
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1434
|
-
minheap_replace_top
|
|
1385
|
+
minheap_replace_top(k, simi, idxi, accu, id);
|
|
1435
1386
|
nup++;
|
|
1436
1387
|
}
|
|
1437
1388
|
codes += code_size;
|
|
@@ -1439,86 +1390,87 @@ struct IVFSQScannerIP: InvertedListScanner {
|
|
|
1439
1390
|
return nup;
|
|
1440
1391
|
}
|
|
1441
1392
|
|
|
1442
|
-
void scan_codes_range
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1393
|
+
void scan_codes_range(
|
|
1394
|
+
size_t list_size,
|
|
1395
|
+
const uint8_t* codes,
|
|
1396
|
+
const idx_t* ids,
|
|
1397
|
+
float radius,
|
|
1398
|
+
RangeQueryResult& res) const override {
|
|
1448
1399
|
for (size_t j = 0; j < list_size; j++) {
|
|
1449
|
-
float accu = accu0 + dc.query_to_code
|
|
1400
|
+
float accu = accu0 + dc.query_to_code(codes);
|
|
1450
1401
|
if (accu > radius) {
|
|
1451
1402
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1452
|
-
res.add
|
|
1403
|
+
res.add(accu, id);
|
|
1453
1404
|
}
|
|
1454
1405
|
codes += code_size;
|
|
1455
1406
|
}
|
|
1456
1407
|
}
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
1408
|
};
|
|
1460
1409
|
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
struct IVFSQScannerL2: InvertedListScanner {
|
|
1464
|
-
|
|
1410
|
+
template <class DCClass>
|
|
1411
|
+
struct IVFSQScannerL2 : InvertedListScanner {
|
|
1465
1412
|
DCClass dc;
|
|
1466
1413
|
|
|
1467
1414
|
bool store_pairs, by_residual;
|
|
1468
1415
|
size_t code_size;
|
|
1469
|
-
const Index
|
|
1470
|
-
idx_t list_no;
|
|
1471
|
-
const float
|
|
1416
|
+
const Index* quantizer;
|
|
1417
|
+
idx_t list_no; /// current inverted list
|
|
1418
|
+
const float* x; /// current query
|
|
1472
1419
|
|
|
1473
1420
|
std::vector<float> tmp;
|
|
1474
1421
|
|
|
1475
|
-
IVFSQScannerL2(
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1422
|
+
IVFSQScannerL2(
|
|
1423
|
+
int d,
|
|
1424
|
+
const std::vector<float>& trained,
|
|
1425
|
+
size_t code_size,
|
|
1426
|
+
const Index* quantizer,
|
|
1427
|
+
bool store_pairs,
|
|
1428
|
+
bool by_residual)
|
|
1429
|
+
: dc(d, trained),
|
|
1430
|
+
store_pairs(store_pairs),
|
|
1431
|
+
by_residual(by_residual),
|
|
1432
|
+
code_size(code_size),
|
|
1433
|
+
quantizer(quantizer),
|
|
1434
|
+
list_no(0),
|
|
1435
|
+
x(nullptr),
|
|
1436
|
+
tmp(d) {}
|
|
1437
|
+
|
|
1438
|
+
void set_query(const float* query) override {
|
|
1486
1439
|
x = query;
|
|
1487
1440
|
if (!quantizer) {
|
|
1488
|
-
dc.set_query
|
|
1441
|
+
dc.set_query(query);
|
|
1489
1442
|
}
|
|
1490
1443
|
}
|
|
1491
1444
|
|
|
1492
|
-
|
|
1493
|
-
void set_list (idx_t list_no, float /*coarse_dis*/) override {
|
|
1445
|
+
void set_list(idx_t list_no, float /*coarse_dis*/) override {
|
|
1494
1446
|
if (by_residual) {
|
|
1495
1447
|
this->list_no = list_no;
|
|
1496
1448
|
// shift of x_in wrt centroid
|
|
1497
|
-
quantizer->compute_residual
|
|
1498
|
-
dc.set_query
|
|
1449
|
+
quantizer->compute_residual(x, tmp.data(), list_no);
|
|
1450
|
+
dc.set_query(tmp.data());
|
|
1499
1451
|
} else {
|
|
1500
|
-
dc.set_query
|
|
1452
|
+
dc.set_query(x);
|
|
1501
1453
|
}
|
|
1502
1454
|
}
|
|
1503
1455
|
|
|
1504
|
-
float distance_to_code
|
|
1505
|
-
return dc.query_to_code
|
|
1456
|
+
float distance_to_code(const uint8_t* code) const final {
|
|
1457
|
+
return dc.query_to_code(code);
|
|
1506
1458
|
}
|
|
1507
1459
|
|
|
1508
|
-
size_t scan_codes
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1460
|
+
size_t scan_codes(
|
|
1461
|
+
size_t list_size,
|
|
1462
|
+
const uint8_t* codes,
|
|
1463
|
+
const idx_t* ids,
|
|
1464
|
+
float* simi,
|
|
1465
|
+
idx_t* idxi,
|
|
1466
|
+
size_t k) const override {
|
|
1514
1467
|
size_t nup = 0;
|
|
1515
1468
|
for (size_t j = 0; j < list_size; j++) {
|
|
1469
|
+
float dis = dc.query_to_code(codes);
|
|
1516
1470
|
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
if (dis < simi [0]) {
|
|
1471
|
+
if (dis < simi[0]) {
|
|
1520
1472
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1521
|
-
maxheap_replace_top
|
|
1473
|
+
maxheap_replace_top(k, simi, idxi, dis, id);
|
|
1522
1474
|
nup++;
|
|
1523
1475
|
}
|
|
1524
1476
|
codes += code_size;
|
|
@@ -1526,137 +1478,132 @@ struct IVFSQScannerL2: InvertedListScanner {
|
|
|
1526
1478
|
return nup;
|
|
1527
1479
|
}
|
|
1528
1480
|
|
|
1529
|
-
void scan_codes_range
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1481
|
+
void scan_codes_range(
|
|
1482
|
+
size_t list_size,
|
|
1483
|
+
const uint8_t* codes,
|
|
1484
|
+
const idx_t* ids,
|
|
1485
|
+
float radius,
|
|
1486
|
+
RangeQueryResult& res) const override {
|
|
1535
1487
|
for (size_t j = 0; j < list_size; j++) {
|
|
1536
|
-
float dis = dc.query_to_code
|
|
1488
|
+
float dis = dc.query_to_code(codes);
|
|
1537
1489
|
if (dis < radius) {
|
|
1538
1490
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1539
|
-
res.add
|
|
1491
|
+
res.add(dis, id);
|
|
1540
1492
|
}
|
|
1541
1493
|
codes += code_size;
|
|
1542
1494
|
}
|
|
1543
1495
|
}
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
1496
|
};
|
|
1547
1497
|
|
|
1548
|
-
template<class DCClass>
|
|
1549
|
-
InvertedListScanner* sel2_InvertedListScanner
|
|
1550
|
-
|
|
1551
|
-
|
|
1552
|
-
|
|
1498
|
+
template <class DCClass>
|
|
1499
|
+
InvertedListScanner* sel2_InvertedListScanner(
|
|
1500
|
+
const ScalarQuantizer* sq,
|
|
1501
|
+
const Index* quantizer,
|
|
1502
|
+
bool store_pairs,
|
|
1503
|
+
bool r) {
|
|
1553
1504
|
if (DCClass::Sim::metric_type == METRIC_L2) {
|
|
1554
|
-
return new IVFSQScannerL2<DCClass>(
|
|
1555
|
-
|
|
1505
|
+
return new IVFSQScannerL2<DCClass>(
|
|
1506
|
+
sq->d, sq->trained, sq->code_size, quantizer, store_pairs, r);
|
|
1556
1507
|
} else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
1557
|
-
return new IVFSQScannerIP<DCClass>(
|
|
1558
|
-
|
|
1508
|
+
return new IVFSQScannerIP<DCClass>(
|
|
1509
|
+
sq->d, sq->trained, sq->code_size, store_pairs, r);
|
|
1559
1510
|
} else {
|
|
1560
1511
|
FAISS_THROW_MSG("unsupported metric type");
|
|
1561
1512
|
}
|
|
1562
1513
|
}
|
|
1563
1514
|
|
|
1564
|
-
template<class Similarity, class Codec, bool uniform>
|
|
1565
|
-
InvertedListScanner* sel12_InvertedListScanner
|
|
1566
|
-
|
|
1567
|
-
|
|
1568
|
-
|
|
1515
|
+
template <class Similarity, class Codec, bool uniform>
|
|
1516
|
+
InvertedListScanner* sel12_InvertedListScanner(
|
|
1517
|
+
const ScalarQuantizer* sq,
|
|
1518
|
+
const Index* quantizer,
|
|
1519
|
+
bool store_pairs,
|
|
1520
|
+
bool r) {
|
|
1569
1521
|
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1570
1522
|
using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
|
|
1571
1523
|
using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
|
|
1572
|
-
return sel2_InvertedListScanner<DCClass>
|
|
1524
|
+
return sel2_InvertedListScanner<DCClass>(sq, quantizer, store_pairs, r);
|
|
1573
1525
|
}
|
|
1574
1526
|
|
|
1575
|
-
|
|
1576
|
-
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
{
|
|
1527
|
+
template <class Similarity>
|
|
1528
|
+
InvertedListScanner* sel1_InvertedListScanner(
|
|
1529
|
+
const ScalarQuantizer* sq,
|
|
1530
|
+
const Index* quantizer,
|
|
1531
|
+
bool store_pairs,
|
|
1532
|
+
bool r) {
|
|
1582
1533
|
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1583
|
-
switch(sq->qtype) {
|
|
1584
|
-
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
|
|
1598
|
-
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1604
|
-
|
|
1605
|
-
|
|
1606
|
-
<
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
<DCTemplate<
|
|
1611
|
-
|
|
1612
|
-
|
|
1613
|
-
|
|
1614
|
-
|
|
1534
|
+
switch (sq->qtype) {
|
|
1535
|
+
case ScalarQuantizer::QT_8bit_uniform:
|
|
1536
|
+
return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
|
|
1537
|
+
sq, quantizer, store_pairs, r);
|
|
1538
|
+
case ScalarQuantizer::QT_4bit_uniform:
|
|
1539
|
+
return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
|
|
1540
|
+
sq, quantizer, store_pairs, r);
|
|
1541
|
+
case ScalarQuantizer::QT_8bit:
|
|
1542
|
+
return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
|
|
1543
|
+
sq, quantizer, store_pairs, r);
|
|
1544
|
+
case ScalarQuantizer::QT_4bit:
|
|
1545
|
+
return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
|
|
1546
|
+
sq, quantizer, store_pairs, r);
|
|
1547
|
+
case ScalarQuantizer::QT_6bit:
|
|
1548
|
+
return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
|
|
1549
|
+
sq, quantizer, store_pairs, r);
|
|
1550
|
+
case ScalarQuantizer::QT_fp16:
|
|
1551
|
+
return sel2_InvertedListScanner<DCTemplate<
|
|
1552
|
+
QuantizerFP16<SIMDWIDTH>,
|
|
1553
|
+
Similarity,
|
|
1554
|
+
SIMDWIDTH>>(sq, quantizer, store_pairs, r);
|
|
1555
|
+
case ScalarQuantizer::QT_8bit_direct:
|
|
1556
|
+
if (sq->d % 16 == 0) {
|
|
1557
|
+
return sel2_InvertedListScanner<
|
|
1558
|
+
DistanceComputerByte<Similarity, SIMDWIDTH>>(
|
|
1559
|
+
sq, quantizer, store_pairs, r);
|
|
1560
|
+
} else {
|
|
1561
|
+
return sel2_InvertedListScanner<DCTemplate<
|
|
1562
|
+
Quantizer8bitDirect<SIMDWIDTH>,
|
|
1563
|
+
Similarity,
|
|
1564
|
+
SIMDWIDTH>>(sq, quantizer, store_pairs, r);
|
|
1565
|
+
}
|
|
1615
1566
|
}
|
|
1616
1567
|
|
|
1617
|
-
FAISS_THROW_MSG
|
|
1568
|
+
FAISS_THROW_MSG("unknown qtype");
|
|
1618
1569
|
return nullptr;
|
|
1619
1570
|
}
|
|
1620
1571
|
|
|
1621
|
-
template<int SIMDWIDTH>
|
|
1622
|
-
InvertedListScanner* sel0_InvertedListScanner
|
|
1623
|
-
|
|
1624
|
-
|
|
1625
|
-
|
|
1572
|
+
template <int SIMDWIDTH>
|
|
1573
|
+
InvertedListScanner* sel0_InvertedListScanner(
|
|
1574
|
+
MetricType mt,
|
|
1575
|
+
const ScalarQuantizer* sq,
|
|
1576
|
+
const Index* quantizer,
|
|
1577
|
+
bool store_pairs,
|
|
1578
|
+
bool by_residual) {
|
|
1626
1579
|
if (mt == METRIC_L2) {
|
|
1627
|
-
return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH
|
|
1628
|
-
|
|
1580
|
+
return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
|
|
1581
|
+
sq, quantizer, store_pairs, by_residual);
|
|
1629
1582
|
} else if (mt == METRIC_INNER_PRODUCT) {
|
|
1630
|
-
return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH
|
|
1631
|
-
|
|
1583
|
+
return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
|
|
1584
|
+
sq, quantizer, store_pairs, by_residual);
|
|
1632
1585
|
} else {
|
|
1633
1586
|
FAISS_THROW_MSG("unsupported metric type");
|
|
1634
1587
|
}
|
|
1635
1588
|
}
|
|
1636
1589
|
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
1590
|
} // anonymous namespace
|
|
1640
1591
|
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1644
|
-
|
|
1645
|
-
{
|
|
1592
|
+
InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
|
|
1593
|
+
MetricType mt,
|
|
1594
|
+
const Index* quantizer,
|
|
1595
|
+
bool store_pairs,
|
|
1596
|
+
bool by_residual) const {
|
|
1646
1597
|
#ifdef USE_F16C
|
|
1647
1598
|
if (d % 8 == 0) {
|
|
1648
|
-
return sel0_InvertedListScanner<8>
|
|
1649
|
-
|
|
1599
|
+
return sel0_InvertedListScanner<8>(
|
|
1600
|
+
mt, this, quantizer, store_pairs, by_residual);
|
|
1650
1601
|
} else
|
|
1651
1602
|
#endif
|
|
1652
1603
|
{
|
|
1653
|
-
return sel0_InvertedListScanner<1>
|
|
1654
|
-
|
|
1604
|
+
return sel0_InvertedListScanner<1>(
|
|
1605
|
+
mt, this, quantizer, store_pairs, by_residual);
|
|
1655
1606
|
}
|
|
1656
1607
|
}
|
|
1657
1608
|
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
1609
|
} // namespace faiss
|