faiss 0.1.3 → 0.1.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +36 -33
- data/vendor/faiss/faiss/AutoTune.h +6 -3
- data/vendor/faiss/faiss/Clustering.cpp +16 -12
- data/vendor/faiss/faiss/Index.cpp +3 -4
- data/vendor/faiss/faiss/Index.h +3 -3
- data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
- data/vendor/faiss/faiss/IndexBinary.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
- data/vendor/faiss/faiss/IndexFlat.h +0 -51
- data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
- data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
- data/vendor/faiss/faiss/IndexIVF.h +22 -15
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
- data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
- data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
- data/vendor/faiss/faiss/IndexRefine.h +73 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
- data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
- data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
- data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
- data/vendor/faiss/faiss/impl/io.cpp +33 -2
- data/vendor/faiss/faiss/impl/io.h +7 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
- data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
- data/vendor/faiss/faiss/index_factory.cpp +112 -7
- data/vendor/faiss/faiss/index_io.h +1 -48
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
- data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
- data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
- data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
- data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
- data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
- data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
- data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
- data/vendor/faiss/faiss/utils/Heap.h +61 -50
- data/vendor/faiss/faiss/utils/distances.cpp +164 -319
- data/vendor/faiss/faiss/utils/distances.h +28 -20
- data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
- data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
- data/vendor/faiss/faiss/utils/hamming.h +2 -7
- data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
- data/vendor/faiss/faiss/utils/partitioning.h +69 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
- data/vendor/faiss/faiss/utils/simdlib.h +31 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
- metadata +43 -141
- data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
- data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
- data/vendor/faiss/c_api/AutoTune_c.h +0 -66
- data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
- data/vendor/faiss/c_api/Clustering_c.h +0 -123
- data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
- data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
- data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
- data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
- data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
- data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
- data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
- data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
- data/vendor/faiss/c_api/IndexShards_c.h +0 -39
- data/vendor/faiss/c_api/Index_c.cpp +0 -105
- data/vendor/faiss/c_api/Index_c.h +0 -183
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
- data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
- data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
- data/vendor/faiss/c_api/clone_index_c.h +0 -32
- data/vendor/faiss/c_api/error_c.h +0 -42
- data/vendor/faiss/c_api/error_impl.cpp +0 -27
- data/vendor/faiss/c_api/error_impl.h +0 -16
- data/vendor/faiss/c_api/faiss_c.h +0 -58
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
- data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
- data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
- data/vendor/faiss/c_api/index_factory_c.h +0 -30
- data/vendor/faiss/c_api/index_io_c.cpp +0 -42
- data/vendor/faiss/c_api/index_io_c.h +0 -50
- data/vendor/faiss/c_api/macros_impl.h +0 -110
- data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
- data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
- data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
- data/vendor/faiss/misc/test_blas.cpp +0 -87
- data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
- data/vendor/faiss/tests/test_merge.cpp +0 -260
- data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
- data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
- data/vendor/faiss/tests/test_params_override.cpp +0 -236
- data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
- data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
- data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
- data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -156,6 +156,14 @@ void pairwise_indexed_inner_product (
|
|
156
156
|
// threshold on nx above which we switch to BLAS to compute distances
|
157
157
|
FAISS_API extern int distance_compute_blas_threshold;
|
158
158
|
|
159
|
+
// block sizes for BLAS distance computations
|
160
|
+
FAISS_API extern int distance_compute_blas_query_bs;
|
161
|
+
FAISS_API extern int distance_compute_blas_database_bs;
|
162
|
+
|
163
|
+
// above this number of results we switch to a reservoir to collect results
|
164
|
+
// rather than a heap
|
165
|
+
FAISS_API extern int distance_compute_min_k_reservoir;
|
166
|
+
|
159
167
|
/** Return the k nearest neighbors of each of the nx vectors x among the ny
|
160
168
|
* vectors y, w.r.t. max inner product
|
161
169
|
*
|
@@ -169,27 +177,17 @@ void knn_inner_product (
|
|
169
177
|
size_t d, size_t nx, size_t ny,
|
170
178
|
float_minheap_array_t * res);
|
171
179
|
|
172
|
-
/** Same as knn_inner_product, for the L2 distance
|
180
|
+
/** Same as knn_inner_product, for the L2 distance
|
181
|
+
* @param y_norm2 norms for the y vectors (nullptr or size ny)
|
182
|
+
*/
|
173
183
|
void knn_L2sqr (
|
174
184
|
const float * x,
|
175
185
|
const float * y,
|
176
186
|
size_t d, size_t nx, size_t ny,
|
177
|
-
float_maxheap_array_t * res
|
187
|
+
float_maxheap_array_t * res,
|
188
|
+
const float *y_norm2 = nullptr);
|
178
189
|
|
179
190
|
|
180
|
-
|
181
|
-
/** same as knn_L2sqr, but base_shift[bno] is subtracted to all
|
182
|
-
* computed distances.
|
183
|
-
*
|
184
|
-
* @param base_shift size ny
|
185
|
-
*/
|
186
|
-
void knn_L2sqr_base_shift (
|
187
|
-
const float * x,
|
188
|
-
const float * y,
|
189
|
-
size_t d, size_t nx, size_t ny,
|
190
|
-
float_maxheap_array_t * res,
|
191
|
-
const float *base_shift);
|
192
|
-
|
193
191
|
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
194
192
|
* indexed by ids. May be useful for re-ranking a pre-selected vector list
|
195
193
|
*/
|
@@ -200,11 +198,12 @@ void knn_inner_products_by_idx (
|
|
200
198
|
size_t d, size_t nx, size_t ny,
|
201
199
|
float_minheap_array_t * res);
|
202
200
|
|
203
|
-
void knn_L2sqr_by_idx (
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
201
|
+
void knn_L2sqr_by_idx (
|
202
|
+
const float * x,
|
203
|
+
const float * y,
|
204
|
+
const int64_t * ids,
|
205
|
+
size_t d, size_t nx, size_t ny,
|
206
|
+
float_maxheap_array_t * res);
|
208
207
|
|
209
208
|
/***************************************************************************
|
210
209
|
* Range search
|
@@ -239,6 +238,15 @@ void range_search_inner_product (
|
|
239
238
|
RangeSearchResult *result);
|
240
239
|
|
241
240
|
|
241
|
+
/***************************************************************************
|
242
|
+
* PQ tables computations
|
243
|
+
***************************************************************************/
|
242
244
|
|
245
|
+
/// specialized function for PQ2
|
246
|
+
void compute_PQ_dis_tables_dsub2(
|
247
|
+
size_t d, size_t ksub, const float *centroids,
|
248
|
+
size_t nx, const float * x,
|
249
|
+
bool is_inner_product,
|
250
|
+
float * dis_tables);
|
243
251
|
|
244
252
|
} // namespace faiss
|
@@ -14,6 +14,9 @@
|
|
14
14
|
#include <cstring>
|
15
15
|
#include <cmath>
|
16
16
|
|
17
|
+
#include <faiss/utils/simdlib.h>
|
18
|
+
#include <faiss/impl/FaissAssert.h>
|
19
|
+
|
17
20
|
#ifdef __SSE3__
|
18
21
|
#include <immintrin.h>
|
19
22
|
#endif
|
@@ -127,6 +130,29 @@ void fvec_L2sqr_ny_ref (float * dis,
|
|
127
130
|
}
|
128
131
|
|
129
132
|
|
133
|
+
void fvec_inner_products_ny_ref (float * ip,
|
134
|
+
const float * x,
|
135
|
+
const float * y,
|
136
|
+
size_t d, size_t ny)
|
137
|
+
{
|
138
|
+
// BLAS slower for the use cases here
|
139
|
+
#if 0
|
140
|
+
{
|
141
|
+
FINTEGER di = d;
|
142
|
+
FINTEGER nyi = ny;
|
143
|
+
float one = 1.0, zero = 0.0;
|
144
|
+
FINTEGER onei = 1;
|
145
|
+
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
146
|
+
}
|
147
|
+
#endif
|
148
|
+
for (size_t i = 0; i < ny; i++) {
|
149
|
+
ip[i] = fvec_inner_product (x, y, d);
|
150
|
+
y += d;
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
|
155
|
+
|
130
156
|
|
131
157
|
|
132
158
|
/*********************************************************
|
@@ -174,12 +200,39 @@ float fvec_norm_L2sqr (const float * x,
|
|
174
200
|
|
175
201
|
namespace {
|
176
202
|
|
177
|
-
|
178
|
-
|
179
|
-
|
203
|
+
/// Function that does a component-wise operation between x and y
|
204
|
+
/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
|
205
|
+
/// functions below
|
206
|
+
struct ElementOpL2 {
|
207
|
+
|
208
|
+
static float op (float x, float y) {
|
209
|
+
float tmp = x - y;
|
210
|
+
return tmp * tmp;
|
211
|
+
}
|
212
|
+
|
213
|
+
static __m128 op (__m128 x, __m128 y) {
|
214
|
+
__m128 tmp = x - y;
|
215
|
+
return tmp * tmp;
|
216
|
+
}
|
217
|
+
|
218
|
+
};
|
180
219
|
|
220
|
+
/// Function that does a component-wise operation between x and y
|
221
|
+
/// to compute inner products
|
222
|
+
struct ElementOpIP {
|
181
223
|
|
182
|
-
|
224
|
+
static float op (float x, float y) {
|
225
|
+
return x * y;
|
226
|
+
}
|
227
|
+
|
228
|
+
static __m128 op (__m128 x, __m128 y) {
|
229
|
+
return x * y;
|
230
|
+
}
|
231
|
+
|
232
|
+
};
|
233
|
+
|
234
|
+
template<class ElementOp>
|
235
|
+
void fvec_op_ny_D1 (float * dis, const float * x,
|
183
236
|
const float * y, size_t ny)
|
184
237
|
{
|
185
238
|
float x0s = x[0];
|
@@ -187,11 +240,9 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
|
|
187
240
|
|
188
241
|
size_t i;
|
189
242
|
for (i = 0; i + 3 < ny; i += 4) {
|
190
|
-
__m128
|
191
|
-
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
192
|
-
accu = tmp * tmp;
|
243
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
193
244
|
dis[i] = _mm_cvtss_f32 (accu);
|
194
|
-
tmp = _mm_shuffle_ps (accu, accu, 1);
|
245
|
+
__m128 tmp = _mm_shuffle_ps (accu, accu, 1);
|
195
246
|
dis[i + 1] = _mm_cvtss_f32 (tmp);
|
196
247
|
tmp = _mm_shuffle_ps (accu, accu, 2);
|
197
248
|
dis[i + 2] = _mm_cvtss_f32 (tmp);
|
@@ -199,69 +250,63 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
|
|
199
250
|
dis[i + 3] = _mm_cvtss_f32 (tmp);
|
200
251
|
}
|
201
252
|
while (i < ny) { // handle non-multiple-of-4 case
|
202
|
-
dis[i++] =
|
253
|
+
dis[i++] = ElementOp::op(x0s, *y++);
|
203
254
|
}
|
204
255
|
}
|
205
256
|
|
206
|
-
|
207
|
-
void
|
257
|
+
template<class ElementOp>
|
258
|
+
void fvec_op_ny_D2 (float * dis, const float * x,
|
208
259
|
const float * y, size_t ny)
|
209
260
|
{
|
210
261
|
__m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
|
211
262
|
|
212
263
|
size_t i;
|
213
264
|
for (i = 0; i + 1 < ny; i += 2) {
|
214
|
-
__m128
|
215
|
-
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
216
|
-
accu = tmp * tmp;
|
265
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
217
266
|
accu = _mm_hadd_ps (accu, accu);
|
218
267
|
dis[i] = _mm_cvtss_f32 (accu);
|
219
268
|
accu = _mm_shuffle_ps (accu, accu, 3);
|
220
269
|
dis[i + 1] = _mm_cvtss_f32 (accu);
|
221
270
|
}
|
222
271
|
if (i < ny) { // handle odd case
|
223
|
-
dis[i] =
|
272
|
+
dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
|
224
273
|
}
|
225
274
|
}
|
226
275
|
|
227
276
|
|
228
277
|
|
229
|
-
|
278
|
+
template<class ElementOp>
|
279
|
+
void fvec_op_ny_D4 (float * dis, const float * x,
|
230
280
|
const float * y, size_t ny)
|
231
281
|
{
|
232
282
|
__m128 x0 = _mm_loadu_ps(x);
|
233
283
|
|
234
284
|
for (size_t i = 0; i < ny; i++) {
|
235
|
-
__m128
|
236
|
-
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
237
|
-
accu = tmp * tmp;
|
285
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
238
286
|
accu = _mm_hadd_ps (accu, accu);
|
239
287
|
accu = _mm_hadd_ps (accu, accu);
|
240
288
|
dis[i] = _mm_cvtss_f32 (accu);
|
241
289
|
}
|
242
290
|
}
|
243
291
|
|
244
|
-
|
245
|
-
void
|
292
|
+
template<class ElementOp>
|
293
|
+
void fvec_op_ny_D8 (float * dis, const float * x,
|
246
294
|
const float * y, size_t ny)
|
247
295
|
{
|
248
296
|
__m128 x0 = _mm_loadu_ps(x);
|
249
297
|
__m128 x1 = _mm_loadu_ps(x + 4);
|
250
298
|
|
251
299
|
for (size_t i = 0; i < ny; i++) {
|
252
|
-
__m128
|
253
|
-
|
254
|
-
accu = tmp * tmp;
|
255
|
-
tmp = x1 - _mm_loadu_ps (y); y += 4;
|
256
|
-
accu += tmp * tmp;
|
300
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
301
|
+
accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
|
257
302
|
accu = _mm_hadd_ps (accu, accu);
|
258
303
|
accu = _mm_hadd_ps (accu, accu);
|
259
304
|
dis[i] = _mm_cvtss_f32 (accu);
|
260
305
|
}
|
261
306
|
}
|
262
307
|
|
263
|
-
|
264
|
-
void
|
308
|
+
template<class ElementOp>
|
309
|
+
void fvec_op_ny_D12 (float * dis, const float * x,
|
265
310
|
const float * y, size_t ny)
|
266
311
|
{
|
267
312
|
__m128 x0 = _mm_loadu_ps(x);
|
@@ -269,13 +314,9 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
|
|
269
314
|
__m128 x2 = _mm_loadu_ps(x + 8);
|
270
315
|
|
271
316
|
for (size_t i = 0; i < ny; i++) {
|
272
|
-
__m128
|
273
|
-
|
274
|
-
accu
|
275
|
-
tmp = x1 - _mm_loadu_ps (y); y += 4;
|
276
|
-
accu += tmp * tmp;
|
277
|
-
tmp = x2 - _mm_loadu_ps (y); y += 4;
|
278
|
-
accu += tmp * tmp;
|
317
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
318
|
+
accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
|
319
|
+
accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
|
279
320
|
accu = _mm_hadd_ps (accu, accu);
|
280
321
|
accu = _mm_hadd_ps (accu, accu);
|
281
322
|
dis[i] = _mm_cvtss_f32 (accu);
|
@@ -283,31 +324,52 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
|
|
283
324
|
}
|
284
325
|
|
285
326
|
|
327
|
+
|
286
328
|
} // anonymous namespace
|
287
329
|
|
288
330
|
void fvec_L2sqr_ny (float * dis, const float * x,
|
289
331
|
const float * y, size_t d, size_t ny) {
|
290
332
|
// optimized for a few special cases
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
case 2:
|
296
|
-
fvec_L2sqr_ny_D2 (dis, x, y, ny);
|
297
|
-
return;
|
298
|
-
case 4:
|
299
|
-
fvec_L2sqr_ny_D4 (dis, x, y, ny);
|
333
|
+
|
334
|
+
#define DISPATCH(dval) \
|
335
|
+
case dval:\
|
336
|
+
fvec_op_ny_D ## dval <ElementOpL2> (dis, x, y, ny); \
|
300
337
|
return;
|
301
|
-
|
302
|
-
|
338
|
+
|
339
|
+
switch(d) {
|
340
|
+
DISPATCH(1)
|
341
|
+
DISPATCH(2)
|
342
|
+
DISPATCH(4)
|
343
|
+
DISPATCH(8)
|
344
|
+
DISPATCH(12)
|
345
|
+
default:
|
346
|
+
fvec_L2sqr_ny_ref (dis, x, y, d, ny);
|
303
347
|
return;
|
304
|
-
|
305
|
-
|
348
|
+
}
|
349
|
+
#undef DISPATCH
|
350
|
+
|
351
|
+
}
|
352
|
+
|
353
|
+
void fvec_inner_products_ny (float * dis, const float * x,
|
354
|
+
const float * y, size_t d, size_t ny) {
|
355
|
+
|
356
|
+
#define DISPATCH(dval) \
|
357
|
+
case dval:\
|
358
|
+
fvec_op_ny_D ## dval <ElementOpIP> (dis, x, y, ny); \
|
306
359
|
return;
|
360
|
+
|
361
|
+
switch(d) {
|
362
|
+
DISPATCH(1)
|
363
|
+
DISPATCH(2)
|
364
|
+
DISPATCH(4)
|
365
|
+
DISPATCH(8)
|
366
|
+
DISPATCH(12)
|
307
367
|
default:
|
308
|
-
|
368
|
+
fvec_inner_products_ny_ref (dis, x, y, d, ny);
|
309
369
|
return;
|
310
370
|
}
|
371
|
+
#undef DISPATCH
|
372
|
+
|
311
373
|
}
|
312
374
|
|
313
375
|
|
@@ -644,6 +706,11 @@ void fvec_L2sqr_ny (float * dis, const float * x,
|
|
644
706
|
fvec_L2sqr_ny_ref (dis, x, y, d, ny);
|
645
707
|
}
|
646
708
|
|
709
|
+
void fvec_inner_products_ny (float * dis, const float * x,
|
710
|
+
const float * y, size_t d, size_t ny) {
|
711
|
+
fvec_inner_products_ny_ref (dis, x, y, d, ny);
|
712
|
+
}
|
713
|
+
|
647
714
|
|
648
715
|
#endif
|
649
716
|
|
@@ -803,6 +870,167 @@ int fvec_madd_and_argmin (size_t n, const float *a,
|
|
803
870
|
#endif
|
804
871
|
|
805
872
|
|
873
|
+
/***************************************************************************
|
874
|
+
* PQ tables computations
|
875
|
+
***************************************************************************/
|
876
|
+
|
877
|
+
#ifdef __AVX2__
|
878
|
+
|
879
|
+
namespace {
|
880
|
+
|
881
|
+
|
882
|
+
// get even float32's of a and b, interleaved
|
883
|
+
simd8float32 geteven(simd8float32 a, simd8float32 b) {
|
884
|
+
return simd8float32(
|
885
|
+
_mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
|
886
|
+
);
|
887
|
+
}
|
888
|
+
|
889
|
+
// get odd float32's of a and b, interleaved
|
890
|
+
simd8float32 getodd(simd8float32 a, simd8float32 b) {
|
891
|
+
return simd8float32(
|
892
|
+
_mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
|
893
|
+
);
|
894
|
+
}
|
895
|
+
|
896
|
+
// 3 cycles
|
897
|
+
// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
|
898
|
+
simd8float32 getlow128(simd8float32 a, simd8float32 b) {
|
899
|
+
return simd8float32(
|
900
|
+
_mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
|
901
|
+
);
|
902
|
+
}
|
903
|
+
|
904
|
+
simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
|
905
|
+
return simd8float32(
|
906
|
+
_mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
|
907
|
+
);
|
908
|
+
}
|
909
|
+
|
910
|
+
/// compute the IP (or the squared L2 distance when is_inner_product is false) for dsub = 2, for 8 centroids and 4 sub-vectors at a time
|
911
|
+
template<bool is_inner_product>
|
912
|
+
void pq2_8cents_table(
|
913
|
+
const simd8float32 centroids[8],
|
914
|
+
const simd8float32 x,
|
915
|
+
float *out, size_t ldo, size_t nout = 4
|
916
|
+
) {
|
917
|
+
|
918
|
+
simd8float32 ips[4];
|
919
|
+
|
920
|
+
for(int i = 0; i < 4; i++) {
|
921
|
+
simd8float32 p1, p2;
|
922
|
+
if (is_inner_product) {
|
923
|
+
p1 = x * centroids[2 * i];
|
924
|
+
p2 = x * centroids[2 * i + 1];
|
925
|
+
} else {
|
926
|
+
p1 = (x - centroids[2 * i]);
|
927
|
+
p1 = p1 * p1;
|
928
|
+
p2 = (x - centroids[2 * i + 1]);
|
929
|
+
p2 = p2 * p2;
|
930
|
+
}
|
931
|
+
ips[i] = hadd(p1, p2);
|
932
|
+
}
|
933
|
+
|
934
|
+
simd8float32 ip02a = geteven(ips[0], ips[1]);
|
935
|
+
simd8float32 ip02b = geteven(ips[2], ips[3]);
|
936
|
+
simd8float32 ip0 = getlow128(ip02a, ip02b);
|
937
|
+
simd8float32 ip2 = gethigh128(ip02a, ip02b);
|
938
|
+
|
939
|
+
simd8float32 ip13a = getodd(ips[0], ips[1]);
|
940
|
+
simd8float32 ip13b = getodd(ips[2], ips[3]);
|
941
|
+
simd8float32 ip1 = getlow128(ip13a, ip13b);
|
942
|
+
simd8float32 ip3 = gethigh128(ip13a, ip13b);
|
943
|
+
|
944
|
+
switch(nout) {
|
945
|
+
case 4:
|
946
|
+
ip3.storeu(out + 3 * ldo);
|
947
|
+
case 3:
|
948
|
+
ip2.storeu(out + 2 * ldo);
|
949
|
+
case 2:
|
950
|
+
ip1.storeu(out + 1 * ldo);
|
951
|
+
case 1:
|
952
|
+
ip0.storeu(out);
|
953
|
+
}
|
954
|
+
}
|
955
|
+
|
956
|
+
simd8float32 load_simd8float32_partial(const float *x, int n) {
|
957
|
+
ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
|
958
|
+
float *wp = tmp;
|
959
|
+
for (int i = 0; i < n; i++) {
|
960
|
+
*wp++ = *x++;
|
961
|
+
}
|
962
|
+
return simd8float32(tmp);
|
963
|
+
}
|
964
|
+
|
965
|
+
} // anonymous namespace
|
966
|
+
|
967
|
+
|
968
|
+
|
969
|
+
|
970
|
+
void compute_PQ_dis_tables_dsub2(
|
971
|
+
size_t d, size_t ksub, const float *all_centroids,
|
972
|
+
size_t nx, const float * x,
|
973
|
+
bool is_inner_product,
|
974
|
+
float * dis_tables)
|
975
|
+
{
|
976
|
+
size_t M = d / 2;
|
977
|
+
FAISS_THROW_IF_NOT(ksub % 8 == 0);
|
978
|
+
|
979
|
+
for(size_t m0 = 0; m0 < M; m0 += 4) {
|
980
|
+
int m1 = std::min(M, m0 + 4);
|
981
|
+
for(int k0 = 0; k0 < ksub; k0 += 8) {
|
982
|
+
|
983
|
+
simd8float32 centroids[8];
|
984
|
+
for (int k = 0; k < 8; k++) {
|
985
|
+
float centroid[8] __attribute__((aligned(32)));
|
986
|
+
size_t wp = 0;
|
987
|
+
size_t rp = (m0 * ksub + k + k0) * 2;
|
988
|
+
for (int m = m0; m < m1; m++) {
|
989
|
+
centroid[wp++] = all_centroids[rp];
|
990
|
+
centroid[wp++] = all_centroids[rp + 1];
|
991
|
+
rp += 2 * ksub;
|
992
|
+
}
|
993
|
+
centroids[k] = simd8float32(centroid);
|
994
|
+
}
|
995
|
+
for(size_t i = 0; i < nx; i++) {
|
996
|
+
simd8float32 xi;
|
997
|
+
if (m1 == m0 + 4) {
|
998
|
+
xi.loadu(x + i * d + m0 * 2);
|
999
|
+
} else {
|
1000
|
+
xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
|
1001
|
+
}
|
1002
|
+
|
1003
|
+
if(is_inner_product) {
|
1004
|
+
pq2_8cents_table<true>(
|
1005
|
+
centroids, xi,
|
1006
|
+
dis_tables + (i * M + m0) * ksub + k0,
|
1007
|
+
ksub, m1 - m0
|
1008
|
+
);
|
1009
|
+
} else {
|
1010
|
+
pq2_8cents_table<false>(
|
1011
|
+
centroids, xi,
|
1012
|
+
dis_tables + (i * M + m0) * ksub + k0,
|
1013
|
+
ksub, m1 - m0
|
1014
|
+
);
|
1015
|
+
}
|
1016
|
+
}
|
1017
|
+
}
|
1018
|
+
}
|
1019
|
+
|
1020
|
+
}
|
1021
|
+
|
1022
|
+
#else
|
1023
|
+
|
1024
|
+
void compute_PQ_dis_tables_dsub2(
|
1025
|
+
size_t d, size_t ksub, const float *all_centroids,
|
1026
|
+
size_t nx, const float * x,
|
1027
|
+
bool is_inner_product,
|
1028
|
+
float * dis_tables)
|
1029
|
+
{
|
1030
|
+
FAISS_THROW_MSG("only implemented for AVX2");
|
1031
|
+
}
|
1032
|
+
|
1033
|
+
#endif
|
806
1034
|
|
807
1035
|
|
808
1036
|
} // namespace faiss
|