faiss 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +36 -33
- data/vendor/faiss/faiss/AutoTune.h +6 -3
- data/vendor/faiss/faiss/Clustering.cpp +16 -12
- data/vendor/faiss/faiss/Index.cpp +3 -4
- data/vendor/faiss/faiss/Index.h +3 -3
- data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
- data/vendor/faiss/faiss/IndexBinary.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
- data/vendor/faiss/faiss/IndexFlat.h +0 -51
- data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
- data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
- data/vendor/faiss/faiss/IndexIVF.h +22 -15
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
- data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
- data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
- data/vendor/faiss/faiss/IndexRefine.h +73 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
- data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
- data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
- data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
- data/vendor/faiss/faiss/impl/io.cpp +33 -2
- data/vendor/faiss/faiss/impl/io.h +7 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
- data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
- data/vendor/faiss/faiss/index_factory.cpp +112 -7
- data/vendor/faiss/faiss/index_io.h +1 -48
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
- data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
- data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
- data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
- data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
- data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
- data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
- data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
- data/vendor/faiss/faiss/utils/Heap.h +61 -50
- data/vendor/faiss/faiss/utils/distances.cpp +164 -319
- data/vendor/faiss/faiss/utils/distances.h +28 -20
- data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
- data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
- data/vendor/faiss/faiss/utils/hamming.h +2 -7
- data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
- data/vendor/faiss/faiss/utils/partitioning.h +69 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
- data/vendor/faiss/faiss/utils/simdlib.h +31 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
- metadata +43 -141
- data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
- data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
- data/vendor/faiss/c_api/AutoTune_c.h +0 -66
- data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
- data/vendor/faiss/c_api/Clustering_c.h +0 -123
- data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
- data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
- data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
- data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
- data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
- data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
- data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
- data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
- data/vendor/faiss/c_api/IndexShards_c.h +0 -39
- data/vendor/faiss/c_api/Index_c.cpp +0 -105
- data/vendor/faiss/c_api/Index_c.h +0 -183
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
- data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
- data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
- data/vendor/faiss/c_api/clone_index_c.h +0 -32
- data/vendor/faiss/c_api/error_c.h +0 -42
- data/vendor/faiss/c_api/error_impl.cpp +0 -27
- data/vendor/faiss/c_api/error_impl.h +0 -16
- data/vendor/faiss/c_api/faiss_c.h +0 -58
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
- data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
- data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
- data/vendor/faiss/c_api/index_factory_c.h +0 -30
- data/vendor/faiss/c_api/index_io_c.cpp +0 -42
- data/vendor/faiss/c_api/index_io_c.h +0 -50
- data/vendor/faiss/c_api/macros_impl.h +0 -110
- data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
- data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
- data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
- data/vendor/faiss/misc/test_blas.cpp +0 -87
- data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
- data/vendor/faiss/tests/test_merge.cpp +0 -260
- data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
- data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
- data/vendor/faiss/tests/test_params_override.cpp +0 -236
- data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
- data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
- data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
- data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -156,6 +156,14 @@ void pairwise_indexed_inner_product (
|
|
156
156
|
// threshold on nx above which we switch to BLAS to compute distances
|
157
157
|
FAISS_API extern int distance_compute_blas_threshold;
|
158
158
|
|
159
|
+
// block sizes for BLAS distance computations
|
160
|
+
FAISS_API extern int distance_compute_blas_query_bs;
|
161
|
+
FAISS_API extern int distance_compute_blas_database_bs;
|
162
|
+
|
163
|
+
// above this number of results we switch to a reservoir to collect results
|
164
|
+
// rather than a heap
|
165
|
+
FAISS_API extern int distance_compute_min_k_reservoir;
|
166
|
+
|
159
167
|
/** Return the k nearest neighors of each of the nx vectors x among the ny
|
160
168
|
* vector y, w.r.t to max inner product
|
161
169
|
*
|
@@ -169,27 +177,17 @@ void knn_inner_product (
|
|
169
177
|
size_t d, size_t nx, size_t ny,
|
170
178
|
float_minheap_array_t * res);
|
171
179
|
|
172
|
-
/** Same as knn_inner_product, for the L2 distance
|
180
|
+
/** Same as knn_inner_product, for the L2 distance
|
181
|
+
* @param y_norm2 norms for the y vectors (nullptr or size ny)
|
182
|
+
*/
|
173
183
|
void knn_L2sqr (
|
174
184
|
const float * x,
|
175
185
|
const float * y,
|
176
186
|
size_t d, size_t nx, size_t ny,
|
177
|
-
float_maxheap_array_t * res
|
187
|
+
float_maxheap_array_t * res,
|
188
|
+
const float *y_norm2 = nullptr);
|
178
189
|
|
179
190
|
|
180
|
-
|
181
|
-
/** same as knn_L2sqr, but base_shift[bno] is subtracted to all
|
182
|
-
* computed distances.
|
183
|
-
*
|
184
|
-
* @param base_shift size ny
|
185
|
-
*/
|
186
|
-
void knn_L2sqr_base_shift (
|
187
|
-
const float * x,
|
188
|
-
const float * y,
|
189
|
-
size_t d, size_t nx, size_t ny,
|
190
|
-
float_maxheap_array_t * res,
|
191
|
-
const float *base_shift);
|
192
|
-
|
193
191
|
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
194
192
|
* indexed by ids. May be useful for re-ranking a pre-selected vector list
|
195
193
|
*/
|
@@ -200,11 +198,12 @@ void knn_inner_products_by_idx (
|
|
200
198
|
size_t d, size_t nx, size_t ny,
|
201
199
|
float_minheap_array_t * res);
|
202
200
|
|
203
|
-
void knn_L2sqr_by_idx (
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
201
|
+
void knn_L2sqr_by_idx (
|
202
|
+
const float * x,
|
203
|
+
const float * y,
|
204
|
+
const int64_t * ids,
|
205
|
+
size_t d, size_t nx, size_t ny,
|
206
|
+
float_maxheap_array_t * res);
|
208
207
|
|
209
208
|
/***************************************************************************
|
210
209
|
* Range search
|
@@ -239,6 +238,15 @@ void range_search_inner_product (
|
|
239
238
|
RangeSearchResult *result);
|
240
239
|
|
241
240
|
|
241
|
+
/***************************************************************************
|
242
|
+
* PQ tables computations
|
243
|
+
***************************************************************************/
|
242
244
|
|
245
|
+
/// specialized function for PQ2
|
246
|
+
void compute_PQ_dis_tables_dsub2(
|
247
|
+
size_t d, size_t ksub, const float *centroids,
|
248
|
+
size_t nx, const float * x,
|
249
|
+
bool is_inner_product,
|
250
|
+
float * dis_tables);
|
243
251
|
|
244
252
|
} // namespace faiss
|
@@ -14,6 +14,9 @@
|
|
14
14
|
#include <cstring>
|
15
15
|
#include <cmath>
|
16
16
|
|
17
|
+
#include <faiss/utils/simdlib.h>
|
18
|
+
#include <faiss/impl/FaissAssert.h>
|
19
|
+
|
17
20
|
#ifdef __SSE3__
|
18
21
|
#include <immintrin.h>
|
19
22
|
#endif
|
@@ -127,6 +130,29 @@ void fvec_L2sqr_ny_ref (float * dis,
|
|
127
130
|
}
|
128
131
|
|
129
132
|
|
133
|
+
void fvec_inner_products_ny_ref (float * ip,
|
134
|
+
const float * x,
|
135
|
+
const float * y,
|
136
|
+
size_t d, size_t ny)
|
137
|
+
{
|
138
|
+
// BLAS slower for the use cases here
|
139
|
+
#if 0
|
140
|
+
{
|
141
|
+
FINTEGER di = d;
|
142
|
+
FINTEGER nyi = ny;
|
143
|
+
float one = 1.0, zero = 0.0;
|
144
|
+
FINTEGER onei = 1;
|
145
|
+
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
146
|
+
}
|
147
|
+
#endif
|
148
|
+
for (size_t i = 0; i < ny; i++) {
|
149
|
+
ip[i] = fvec_inner_product (x, y, d);
|
150
|
+
y += d;
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
|
155
|
+
|
130
156
|
|
131
157
|
|
132
158
|
/*********************************************************
|
@@ -174,12 +200,39 @@ float fvec_norm_L2sqr (const float * x,
|
|
174
200
|
|
175
201
|
namespace {
|
176
202
|
|
177
|
-
|
178
|
-
|
179
|
-
|
203
|
+
/// Function that does a component-wise operation between x and y
|
204
|
+
/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
|
205
|
+
/// functions below
|
206
|
+
struct ElementOpL2 {
|
207
|
+
|
208
|
+
static float op (float x, float y) {
|
209
|
+
float tmp = x - y;
|
210
|
+
return tmp * tmp;
|
211
|
+
}
|
212
|
+
|
213
|
+
static __m128 op (__m128 x, __m128 y) {
|
214
|
+
__m128 tmp = x - y;
|
215
|
+
return tmp * tmp;
|
216
|
+
}
|
217
|
+
|
218
|
+
};
|
180
219
|
|
220
|
+
/// Function that does a component-wise operation between x and y
|
221
|
+
/// to compute inner products
|
222
|
+
struct ElementOpIP {
|
181
223
|
|
182
|
-
|
224
|
+
static float op (float x, float y) {
|
225
|
+
return x * y;
|
226
|
+
}
|
227
|
+
|
228
|
+
static __m128 op (__m128 x, __m128 y) {
|
229
|
+
return x * y;
|
230
|
+
}
|
231
|
+
|
232
|
+
};
|
233
|
+
|
234
|
+
template<class ElementOp>
|
235
|
+
void fvec_op_ny_D1 (float * dis, const float * x,
|
183
236
|
const float * y, size_t ny)
|
184
237
|
{
|
185
238
|
float x0s = x[0];
|
@@ -187,11 +240,9 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
|
|
187
240
|
|
188
241
|
size_t i;
|
189
242
|
for (i = 0; i + 3 < ny; i += 4) {
|
190
|
-
__m128
|
191
|
-
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
192
|
-
accu = tmp * tmp;
|
243
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
193
244
|
dis[i] = _mm_cvtss_f32 (accu);
|
194
|
-
tmp = _mm_shuffle_ps (accu, accu, 1);
|
245
|
+
__m128 tmp = _mm_shuffle_ps (accu, accu, 1);
|
195
246
|
dis[i + 1] = _mm_cvtss_f32 (tmp);
|
196
247
|
tmp = _mm_shuffle_ps (accu, accu, 2);
|
197
248
|
dis[i + 2] = _mm_cvtss_f32 (tmp);
|
@@ -199,69 +250,63 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
|
|
199
250
|
dis[i + 3] = _mm_cvtss_f32 (tmp);
|
200
251
|
}
|
201
252
|
while (i < ny) { // handle non-multiple-of-4 case
|
202
|
-
dis[i++] =
|
253
|
+
dis[i++] = ElementOp::op(x0s, *y++);
|
203
254
|
}
|
204
255
|
}
|
205
256
|
|
206
|
-
|
207
|
-
void
|
257
|
+
template<class ElementOp>
|
258
|
+
void fvec_op_ny_D2 (float * dis, const float * x,
|
208
259
|
const float * y, size_t ny)
|
209
260
|
{
|
210
261
|
__m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
|
211
262
|
|
212
263
|
size_t i;
|
213
264
|
for (i = 0; i + 1 < ny; i += 2) {
|
214
|
-
__m128
|
215
|
-
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
216
|
-
accu = tmp * tmp;
|
265
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
217
266
|
accu = _mm_hadd_ps (accu, accu);
|
218
267
|
dis[i] = _mm_cvtss_f32 (accu);
|
219
268
|
accu = _mm_shuffle_ps (accu, accu, 3);
|
220
269
|
dis[i + 1] = _mm_cvtss_f32 (accu);
|
221
270
|
}
|
222
271
|
if (i < ny) { // handle odd case
|
223
|
-
dis[i] =
|
272
|
+
dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
|
224
273
|
}
|
225
274
|
}
|
226
275
|
|
227
276
|
|
228
277
|
|
229
|
-
|
278
|
+
template<class ElementOp>
|
279
|
+
void fvec_op_ny_D4 (float * dis, const float * x,
|
230
280
|
const float * y, size_t ny)
|
231
281
|
{
|
232
282
|
__m128 x0 = _mm_loadu_ps(x);
|
233
283
|
|
234
284
|
for (size_t i = 0; i < ny; i++) {
|
235
|
-
__m128
|
236
|
-
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
237
|
-
accu = tmp * tmp;
|
285
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
238
286
|
accu = _mm_hadd_ps (accu, accu);
|
239
287
|
accu = _mm_hadd_ps (accu, accu);
|
240
288
|
dis[i] = _mm_cvtss_f32 (accu);
|
241
289
|
}
|
242
290
|
}
|
243
291
|
|
244
|
-
|
245
|
-
void
|
292
|
+
template<class ElementOp>
|
293
|
+
void fvec_op_ny_D8 (float * dis, const float * x,
|
246
294
|
const float * y, size_t ny)
|
247
295
|
{
|
248
296
|
__m128 x0 = _mm_loadu_ps(x);
|
249
297
|
__m128 x1 = _mm_loadu_ps(x + 4);
|
250
298
|
|
251
299
|
for (size_t i = 0; i < ny; i++) {
|
252
|
-
__m128
|
253
|
-
|
254
|
-
accu = tmp * tmp;
|
255
|
-
tmp = x1 - _mm_loadu_ps (y); y += 4;
|
256
|
-
accu += tmp * tmp;
|
300
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
301
|
+
accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
|
257
302
|
accu = _mm_hadd_ps (accu, accu);
|
258
303
|
accu = _mm_hadd_ps (accu, accu);
|
259
304
|
dis[i] = _mm_cvtss_f32 (accu);
|
260
305
|
}
|
261
306
|
}
|
262
307
|
|
263
|
-
|
264
|
-
void
|
308
|
+
template<class ElementOp>
|
309
|
+
void fvec_op_ny_D12 (float * dis, const float * x,
|
265
310
|
const float * y, size_t ny)
|
266
311
|
{
|
267
312
|
__m128 x0 = _mm_loadu_ps(x);
|
@@ -269,13 +314,9 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
|
|
269
314
|
__m128 x2 = _mm_loadu_ps(x + 8);
|
270
315
|
|
271
316
|
for (size_t i = 0; i < ny; i++) {
|
272
|
-
__m128
|
273
|
-
|
274
|
-
accu
|
275
|
-
tmp = x1 - _mm_loadu_ps (y); y += 4;
|
276
|
-
accu += tmp * tmp;
|
277
|
-
tmp = x2 - _mm_loadu_ps (y); y += 4;
|
278
|
-
accu += tmp * tmp;
|
317
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
|
318
|
+
accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
|
319
|
+
accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
|
279
320
|
accu = _mm_hadd_ps (accu, accu);
|
280
321
|
accu = _mm_hadd_ps (accu, accu);
|
281
322
|
dis[i] = _mm_cvtss_f32 (accu);
|
@@ -283,31 +324,52 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
|
|
283
324
|
}
|
284
325
|
|
285
326
|
|
327
|
+
|
286
328
|
} // anonymous namespace
|
287
329
|
|
288
330
|
void fvec_L2sqr_ny (float * dis, const float * x,
|
289
331
|
const float * y, size_t d, size_t ny) {
|
290
332
|
// optimized for a few special cases
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
case 2:
|
296
|
-
fvec_L2sqr_ny_D2 (dis, x, y, ny);
|
297
|
-
return;
|
298
|
-
case 4:
|
299
|
-
fvec_L2sqr_ny_D4 (dis, x, y, ny);
|
333
|
+
|
334
|
+
#define DISPATCH(dval) \
|
335
|
+
case dval:\
|
336
|
+
fvec_op_ny_D ## dval <ElementOpL2> (dis, x, y, ny); \
|
300
337
|
return;
|
301
|
-
|
302
|
-
|
338
|
+
|
339
|
+
switch(d) {
|
340
|
+
DISPATCH(1)
|
341
|
+
DISPATCH(2)
|
342
|
+
DISPATCH(4)
|
343
|
+
DISPATCH(8)
|
344
|
+
DISPATCH(12)
|
345
|
+
default:
|
346
|
+
fvec_L2sqr_ny_ref (dis, x, y, d, ny);
|
303
347
|
return;
|
304
|
-
|
305
|
-
|
348
|
+
}
|
349
|
+
#undef DISPATCH
|
350
|
+
|
351
|
+
}
|
352
|
+
|
353
|
+
void fvec_inner_products_ny (float * dis, const float * x,
|
354
|
+
const float * y, size_t d, size_t ny) {
|
355
|
+
|
356
|
+
#define DISPATCH(dval) \
|
357
|
+
case dval:\
|
358
|
+
fvec_op_ny_D ## dval <ElementOpIP> (dis, x, y, ny); \
|
306
359
|
return;
|
360
|
+
|
361
|
+
switch(d) {
|
362
|
+
DISPATCH(1)
|
363
|
+
DISPATCH(2)
|
364
|
+
DISPATCH(4)
|
365
|
+
DISPATCH(8)
|
366
|
+
DISPATCH(12)
|
307
367
|
default:
|
308
|
-
|
368
|
+
fvec_inner_products_ny_ref (dis, x, y, d, ny);
|
309
369
|
return;
|
310
370
|
}
|
371
|
+
#undef DISPATCH
|
372
|
+
|
311
373
|
}
|
312
374
|
|
313
375
|
|
@@ -644,6 +706,11 @@ void fvec_L2sqr_ny (float * dis, const float * x,
|
|
644
706
|
fvec_L2sqr_ny_ref (dis, x, y, d, ny);
|
645
707
|
}
|
646
708
|
|
709
|
+
void fvec_inner_products_ny (float * dis, const float * x,
|
710
|
+
const float * y, size_t d, size_t ny) {
|
711
|
+
fvec_inner_products_ny_ref (dis, x, y, d, ny);
|
712
|
+
}
|
713
|
+
|
647
714
|
|
648
715
|
#endif
|
649
716
|
|
@@ -803,6 +870,167 @@ int fvec_madd_and_argmin (size_t n, const float *a,
|
|
803
870
|
#endif
|
804
871
|
|
805
872
|
|
873
|
+
/***************************************************************************
|
874
|
+
* PQ tables computations
|
875
|
+
***************************************************************************/
|
876
|
+
|
877
|
+
#ifdef __AVX2__
|
878
|
+
|
879
|
+
namespace {
|
880
|
+
|
881
|
+
|
882
|
+
// get even float32's of a and b, interleaved
|
883
|
+
simd8float32 geteven(simd8float32 a, simd8float32 b) {
|
884
|
+
return simd8float32(
|
885
|
+
_mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
|
886
|
+
);
|
887
|
+
}
|
888
|
+
|
889
|
+
// get odd float32's of a and b, interleaved
|
890
|
+
simd8float32 getodd(simd8float32 a, simd8float32 b) {
|
891
|
+
return simd8float32(
|
892
|
+
_mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
|
893
|
+
);
|
894
|
+
}
|
895
|
+
|
896
|
+
// 3 cycles
|
897
|
+
// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
|
898
|
+
simd8float32 getlow128(simd8float32 a, simd8float32 b) {
|
899
|
+
return simd8float32(
|
900
|
+
_mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
|
901
|
+
);
|
902
|
+
}
|
903
|
+
|
904
|
+
simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
|
905
|
+
return simd8float32(
|
906
|
+
_mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
|
907
|
+
);
|
908
|
+
}
|
909
|
+
|
910
|
+
/// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
|
911
|
+
template<bool is_inner_product>
|
912
|
+
void pq2_8cents_table(
|
913
|
+
const simd8float32 centroids[8],
|
914
|
+
const simd8float32 x,
|
915
|
+
float *out, size_t ldo, size_t nout = 4
|
916
|
+
) {
|
917
|
+
|
918
|
+
simd8float32 ips[4];
|
919
|
+
|
920
|
+
for(int i = 0; i < 4; i++) {
|
921
|
+
simd8float32 p1, p2;
|
922
|
+
if (is_inner_product) {
|
923
|
+
p1 = x * centroids[2 * i];
|
924
|
+
p2 = x * centroids[2 * i + 1];
|
925
|
+
} else {
|
926
|
+
p1 = (x - centroids[2 * i]);
|
927
|
+
p1 = p1 * p1;
|
928
|
+
p2 = (x - centroids[2 * i + 1]);
|
929
|
+
p2 = p2 * p2;
|
930
|
+
}
|
931
|
+
ips[i] = hadd(p1, p2);
|
932
|
+
}
|
933
|
+
|
934
|
+
simd8float32 ip02a = geteven(ips[0], ips[1]);
|
935
|
+
simd8float32 ip02b = geteven(ips[2], ips[3]);
|
936
|
+
simd8float32 ip0 = getlow128(ip02a, ip02b);
|
937
|
+
simd8float32 ip2 = gethigh128(ip02a, ip02b);
|
938
|
+
|
939
|
+
simd8float32 ip13a = getodd(ips[0], ips[1]);
|
940
|
+
simd8float32 ip13b = getodd(ips[2], ips[3]);
|
941
|
+
simd8float32 ip1 = getlow128(ip13a, ip13b);
|
942
|
+
simd8float32 ip3 = gethigh128(ip13a, ip13b);
|
943
|
+
|
944
|
+
switch(nout) {
|
945
|
+
case 4:
|
946
|
+
ip3.storeu(out + 3 * ldo);
|
947
|
+
case 3:
|
948
|
+
ip2.storeu(out + 2 * ldo);
|
949
|
+
case 2:
|
950
|
+
ip1.storeu(out + 1 * ldo);
|
951
|
+
case 1:
|
952
|
+
ip0.storeu(out);
|
953
|
+
}
|
954
|
+
}
|
955
|
+
|
956
|
+
simd8float32 load_simd8float32_partial(const float *x, int n) {
|
957
|
+
ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
|
958
|
+
float *wp = tmp;
|
959
|
+
for (int i = 0; i < n; i++) {
|
960
|
+
*wp++ = *x++;
|
961
|
+
}
|
962
|
+
return simd8float32(tmp);
|
963
|
+
}
|
964
|
+
|
965
|
+
} // anonymous namespace
|
966
|
+
|
967
|
+
|
968
|
+
|
969
|
+
|
970
|
+
void compute_PQ_dis_tables_dsub2(
|
971
|
+
size_t d, size_t ksub, const float *all_centroids,
|
972
|
+
size_t nx, const float * x,
|
973
|
+
bool is_inner_product,
|
974
|
+
float * dis_tables)
|
975
|
+
{
|
976
|
+
size_t M = d / 2;
|
977
|
+
FAISS_THROW_IF_NOT(ksub % 8 == 0);
|
978
|
+
|
979
|
+
for(size_t m0 = 0; m0 < M; m0 += 4) {
|
980
|
+
int m1 = std::min(M, m0 + 4);
|
981
|
+
for(int k0 = 0; k0 < ksub; k0 += 8) {
|
982
|
+
|
983
|
+
simd8float32 centroids[8];
|
984
|
+
for (int k = 0; k < 8; k++) {
|
985
|
+
float centroid[8] __attribute__((aligned(32)));
|
986
|
+
size_t wp = 0;
|
987
|
+
size_t rp = (m0 * ksub + k + k0) * 2;
|
988
|
+
for (int m = m0; m < m1; m++) {
|
989
|
+
centroid[wp++] = all_centroids[rp];
|
990
|
+
centroid[wp++] = all_centroids[rp + 1];
|
991
|
+
rp += 2 * ksub;
|
992
|
+
}
|
993
|
+
centroids[k] = simd8float32(centroid);
|
994
|
+
}
|
995
|
+
for(size_t i = 0; i < nx; i++) {
|
996
|
+
simd8float32 xi;
|
997
|
+
if (m1 == m0 + 4) {
|
998
|
+
xi.loadu(x + i * d + m0 * 2);
|
999
|
+
} else {
|
1000
|
+
xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
|
1001
|
+
}
|
1002
|
+
|
1003
|
+
if(is_inner_product) {
|
1004
|
+
pq2_8cents_table<true>(
|
1005
|
+
centroids, xi,
|
1006
|
+
dis_tables + (i * M + m0) * ksub + k0,
|
1007
|
+
ksub, m1 - m0
|
1008
|
+
);
|
1009
|
+
} else {
|
1010
|
+
pq2_8cents_table<false>(
|
1011
|
+
centroids, xi,
|
1012
|
+
dis_tables + (i * M + m0) * ksub + k0,
|
1013
|
+
ksub, m1 - m0
|
1014
|
+
);
|
1015
|
+
}
|
1016
|
+
}
|
1017
|
+
}
|
1018
|
+
}
|
1019
|
+
|
1020
|
+
}
|
1021
|
+
|
1022
|
+
#else
|
1023
|
+
|
1024
|
+
void compute_PQ_dis_tables_dsub2(
|
1025
|
+
size_t d, size_t ksub, const float *all_centroids,
|
1026
|
+
size_t nx, const float * x,
|
1027
|
+
bool is_inner_product,
|
1028
|
+
float * dis_tables)
|
1029
|
+
{
|
1030
|
+
FAISS_THROW_MSG("only implemented for AVX2");
|
1031
|
+
}
|
1032
|
+
|
1033
|
+
#endif
|
806
1034
|
|
807
1035
|
|
808
1036
|
} // namespace faiss
|