faiss 0.3.4 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +11 -8
  5. data/vendor/faiss/faiss/Clustering.cpp +0 -16
  6. data/vendor/faiss/faiss/IVFlib.cpp +213 -0
  7. data/vendor/faiss/faiss/IVFlib.h +42 -0
  8. data/vendor/faiss/faiss/Index.h +1 -1
  9. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -7
  10. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -1
  11. data/vendor/faiss/faiss/IndexFlatCodes.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFlatCodes.h +4 -2
  13. data/vendor/faiss/faiss/IndexHNSW.cpp +13 -20
  14. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  15. data/vendor/faiss/faiss/IndexIVF.cpp +20 -3
  16. data/vendor/faiss/faiss/IndexIVF.h +5 -2
  17. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -1
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +2 -1
  19. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -1
  20. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -1
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +277 -0
  24. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +70 -0
  25. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -1
  26. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  27. data/vendor/faiss/faiss/IndexRaBitQ.cpp +148 -0
  28. data/vendor/faiss/faiss/IndexRaBitQ.h +65 -0
  29. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -1
  30. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -1
  31. data/vendor/faiss/faiss/clone_index.cpp +38 -3
  32. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +19 -0
  33. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +4 -11
  34. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -1
  35. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +13 -3
  36. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  37. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +1 -1
  38. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +112 -0
  39. data/vendor/faiss/faiss/impl/HNSW.cpp +35 -13
  40. data/vendor/faiss/faiss/impl/HNSW.h +5 -4
  41. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  42. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +519 -0
  43. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +78 -0
  44. data/vendor/faiss/faiss/impl/ResultHandler.h +2 -2
  45. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +3 -4
  46. data/vendor/faiss/faiss/impl/index_read.cpp +220 -25
  47. data/vendor/faiss/faiss/impl/index_write.cpp +29 -0
  48. data/vendor/faiss/faiss/impl/io.h +2 -2
  49. data/vendor/faiss/faiss/impl/io_macros.h +2 -0
  50. data/vendor/faiss/faiss/impl/mapped_io.cpp +313 -0
  51. data/vendor/faiss/faiss/impl/mapped_io.h +51 -0
  52. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +316 -0
  53. data/vendor/faiss/faiss/impl/platform_macros.h +7 -3
  54. data/vendor/faiss/faiss/impl/simd_result_handlers.h +1 -1
  55. data/vendor/faiss/faiss/impl/zerocopy_io.cpp +67 -0
  56. data/vendor/faiss/faiss/impl/zerocopy_io.h +32 -0
  57. data/vendor/faiss/faiss/index_factory.cpp +16 -5
  58. data/vendor/faiss/faiss/index_io.h +4 -0
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +3 -3
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +5 -3
  61. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +3 -3
  62. data/vendor/faiss/faiss/python/python_callbacks.cpp +24 -0
  63. data/vendor/faiss/faiss/python/python_callbacks.h +22 -0
  64. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +30 -12
  65. data/vendor/faiss/faiss/utils/hamming.cpp +45 -21
  66. data/vendor/faiss/faiss/utils/hamming.h +7 -3
  67. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +1 -1
  68. data/vendor/faiss/faiss/utils/utils.cpp +4 -4
  69. data/vendor/faiss/faiss/utils/utils.h +3 -3
  70. metadata +16 -4
@@ -11,7 +11,7 @@
11
11
  #include <cstdint>
12
12
  #include <cstdio>
13
13
 
14
- #ifdef _MSC_VER
14
+ #ifdef _WIN32
15
15
 
16
16
  /*******************************************************
17
17
  * Windows specific macros
@@ -23,11 +23,11 @@
23
23
  #define FAISS_API __declspec(dllimport)
24
24
  #endif // FAISS_MAIN_LIB
25
25
 
26
- #ifdef _MSC_VER
27
26
  #define strtok_r strtok_s
28
- #endif // _MSC_VER
29
27
 
28
+ #ifdef _MSC_VER
30
29
  #define __PRETTY_FUNCTION__ __FUNCSIG__
30
+ #endif // _MSC_VER
31
31
 
32
32
  #define posix_memalign(p, a, s) \
33
33
  (((*(p)) = _aligned_malloc((s), (a))), *(p) ? 0 : errno)
@@ -37,6 +37,7 @@
37
37
  #define ALIGNED(x) __declspec(align(x))
38
38
 
39
39
  // redefine the GCC intrinsics with Windows equivalents
40
+ #ifdef _MSC_VER
40
41
 
41
42
  #include <intrin.h>
42
43
  #include <limits.h>
@@ -75,6 +76,7 @@ inline int __builtin_clzll(uint64_t x) {
75
76
 
76
77
  #define __builtin_popcount __popcnt
77
78
  #define __builtin_popcountl __popcnt64
79
+ #define __builtin_popcountll __popcnt64
78
80
 
79
81
  #ifndef __clang__
80
82
  #define __m128i_u __m128i
@@ -101,6 +103,8 @@ inline int __builtin_clzll(uint64_t x) {
101
103
  #define __F16C__ 1
102
104
  #endif
103
105
 
106
+ #endif // _MSC_VER
107
+
104
108
  #define FAISS_ALWAYS_INLINE __forceinline
105
109
 
106
110
  #else
@@ -576,7 +576,7 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
576
576
  normalizers = norms;
577
577
  for (int q = 0; q < nq; ++q) {
578
578
  thresholds[q] =
579
- normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
579
+ int(normalizers[2 * q] * (radius - normalizers[2 * q + 1]));
580
580
  }
581
581
  }
582
582
 
@@ -0,0 +1,67 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/zerocopy_io.h>
9
+ #include <cstring>
10
+
11
+ namespace faiss {
12
+
13
+ ZeroCopyIOReader::ZeroCopyIOReader(uint8_t* data, size_t size)
14
+ : data_(data), rp_(0), total_(size) {}
15
+
16
+ ZeroCopyIOReader::~ZeroCopyIOReader() {}
17
+
18
+ size_t ZeroCopyIOReader::get_data_view(void** ptr, size_t size, size_t nitems) {
19
+ if (size == 0) {
20
+ return nitems;
21
+ }
22
+
23
+ size_t actual_size = size * nitems;
24
+ if (rp_ + size * nitems > total_) {
25
+ actual_size = total_ - rp_;
26
+ }
27
+
28
+ size_t actual_nitems = (actual_size + size - 1) / size;
29
+ if (actual_nitems == 0) {
30
+ return 0;
31
+ }
32
+
33
+ // get an address
34
+ *ptr = (void*)(reinterpret_cast<const char*>(data_ + rp_));
35
+
36
+ // alter pos
37
+ rp_ += size * actual_nitems;
38
+
39
+ return actual_nitems;
40
+ }
41
+
42
+ void ZeroCopyIOReader::reset() {
43
+ rp_ = 0;
44
+ }
45
+
46
+ size_t ZeroCopyIOReader::operator()(void* ptr, size_t size, size_t nitems) {
47
+ if (size * nitems == 0) {
48
+ return 0;
49
+ }
50
+
51
+ if (rp_ >= total_) {
52
+ return 0;
53
+ }
54
+ size_t nremain = (total_ - rp_) / size;
55
+ if (nremain < nitems) {
56
+ nitems = nremain;
57
+ }
58
+ memcpy(ptr, (data_ + rp_), size * nitems);
59
+ rp_ += size * nitems;
60
+ return nitems;
61
+ }
62
+
63
+ int ZeroCopyIOReader::filedescriptor() {
64
+ return -1; // Indicating no file descriptor available for memory buffer
65
+ }
66
+
67
+ } // namespace faiss
@@ -0,0 +1,32 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstdint>
11
+
12
+ #include <faiss/impl/io.h>
13
+
14
+ namespace faiss {
15
+
16
+ // ZeroCopyIOReader just maps the data from a given pointer.
17
+ struct ZeroCopyIOReader : public faiss::IOReader {
18
+ uint8_t* data_;
19
+ size_t rp_ = 0;
20
+ size_t total_ = 0;
21
+
22
+ ZeroCopyIOReader(uint8_t* data, size_t size);
23
+ ~ZeroCopyIOReader();
24
+
25
+ void reset();
26
+ size_t get_data_view(void** ptr, size_t size, size_t nitems);
27
+ size_t operator()(void* ptr, size_t size, size_t nitems) override;
28
+
29
+ int filedescriptor() override;
30
+ };
31
+
32
+ } // namespace faiss
@@ -11,9 +11,6 @@
11
11
 
12
12
  #include <faiss/index_factory.h>
13
13
 
14
- #include <cinttypes>
15
- #include <cmath>
16
-
17
14
  #include <map>
18
15
 
19
16
  #include <regex>
@@ -33,6 +30,7 @@
33
30
  #include <faiss/IndexIVFPQ.h>
34
31
  #include <faiss/IndexIVFPQFastScan.h>
35
32
  #include <faiss/IndexIVFPQR.h>
33
+ #include <faiss/IndexIVFRaBitQ.h>
36
34
  #include <faiss/IndexIVFSpectralHash.h>
37
35
  #include <faiss/IndexLSH.h>
38
36
  #include <faiss/IndexLattice.h>
@@ -40,6 +38,7 @@
40
38
  #include <faiss/IndexPQ.h>
41
39
  #include <faiss/IndexPQFastScan.h>
42
40
  #include <faiss/IndexPreTransform.h>
41
+ #include <faiss/IndexRaBitQ.h>
43
42
  #include <faiss/IndexRefine.h>
44
43
  #include <faiss/IndexRowwiseMinMax.h>
45
44
  #include <faiss/IndexScalarQuantizer.h>
@@ -67,6 +66,7 @@ namespace {
67
66
  */
68
67
 
69
68
  bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
69
+ // @lint-ignore CLANGTIDY
70
70
  return std::regex_match(s, sm, std::regex(pat));
71
71
  }
72
72
 
@@ -164,7 +164,7 @@ const std::string aq_norm_pattern =
164
164
  const std::string paq_def_pattern = "([0-9]+)x([0-9]+)x([0-9]+)";
165
165
 
166
166
  AdditiveQuantizer::Search_type_t aq_parse_search_type(
167
- std::string stok,
167
+ const std::string& stok,
168
168
  MetricType metric) {
169
169
  if (stok == "") {
170
170
  return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
@@ -177,6 +177,7 @@ AdditiveQuantizer::Search_type_t aq_parse_search_type(
177
177
  std::vector<size_t> aq_parse_nbits(std::string stok) {
178
178
  std::vector<size_t> nbits;
179
179
  std::smatch sm;
180
+ // @lint-ignore CLANGTIDY
180
181
  while (std::regex_search(stok, sm, std::regex("[^q]([0-9]+)x([0-9]+)"))) {
181
182
  int M = std::stoi(sm[1].str());
182
183
  int nbit = std::stoi(sm[2].str());
@@ -186,6 +187,8 @@ std::vector<size_t> aq_parse_nbits(std::string stok) {
186
187
  return nbits;
187
188
  }
188
189
 
190
+ const std::string rabitq_pattern = "(RaBitQ)";
191
+
189
192
  /***************************************************************
190
193
  * Parse VectorTransform
191
194
  */
@@ -436,6 +439,9 @@ IndexIVF* parse_IndexIVF(
436
439
  }
437
440
  return index_ivf;
438
441
  }
442
+ if (match(rabitq_pattern)) {
443
+ return new IndexIVFRaBitQ(get_q(), d, nlist, mt);
444
+ }
439
445
  return nullptr;
440
446
  }
441
447
 
@@ -657,6 +663,11 @@ Index* parse_other_indexes(
657
663
  }
658
664
  }
659
665
 
666
+ // IndexRaBitQ
667
+ if (match(rabitq_pattern)) {
668
+ return new IndexRaBitQ(d, metric);
669
+ }
670
+
660
671
  return nullptr;
661
672
  }
662
673
 
@@ -766,7 +777,7 @@ std::unique_ptr<Index> index_factory_sub(
766
777
  }
767
778
 
768
779
  if (verbose) {
769
- printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
780
+ printf("after () normalization: %s %zd parenthesis indexes d=%d\n",
770
781
  description.c_str(),
771
782
  parenthesis_indexes.size(),
772
783
  d);
@@ -62,6 +62,10 @@ const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
62
62
  // try to memmap data (useful to load an ArrayInvertedLists as an
63
63
  // OnDiskInvertedLists)
64
64
  const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;
65
+ // mmap that handles codes for IndexFlatCodes-derived indices and HNSW.
66
+ // this is a temporary solution, it is expected to be merged with IO_FLAG_MMAP
67
+ // after OnDiskInvertedLists get properly updated.
68
+ const int IO_FLAG_MMAP_IFC = 1 << 9;
65
69
 
66
70
  Index* read_index(const char* fname, int io_flags = 0);
67
71
  Index* read_index(FILE* f, int io_flags = 0);
@@ -181,7 +181,7 @@ size_t InvertedLists::copy_subset_to(
181
181
  }
182
182
 
183
183
  double InvertedLists::imbalance_factor() const {
184
- std::vector<int> hist(nlist);
184
+ std::vector<int64_t> hist(nlist);
185
185
 
186
186
  for (size_t i = 0; i < nlist; i++) {
187
187
  hist[i] = list_size(i);
@@ -330,8 +330,8 @@ void ArrayInvertedLists::update_entries(
330
330
  }
331
331
 
332
332
  void ArrayInvertedLists::permute_invlists(const idx_t* map) {
333
- std::vector<std::vector<uint8_t>> new_codes(nlist);
334
- std::vector<std::vector<idx_t>> new_ids(nlist);
333
+ std::vector<MaybeOwnedVector<uint8_t>> new_codes(nlist);
334
+ std::vector<MaybeOwnedVector<idx_t>> new_ids(nlist);
335
335
 
336
336
  for (size_t i = 0; i < nlist; i++) {
337
337
  size_t o = map[i];
@@ -15,9 +15,11 @@
15
15
  * the interface.
16
16
  */
17
17
 
18
- #include <faiss/MetricType.h>
19
18
  #include <vector>
20
19
 
20
+ #include <faiss/MetricType.h>
21
+ #include <faiss/impl/maybe_owned_vector.h>
22
+
21
23
  namespace faiss {
22
24
 
23
25
  struct InvertedListsIterator {
@@ -241,8 +243,8 @@ struct InvertedLists {
241
243
 
242
244
  /// simple (default) implementation as an array of inverted lists
243
245
  struct ArrayInvertedLists : InvertedLists {
244
- std::vector<std::vector<uint8_t>> codes; // binary codes, size nlist
245
- std::vector<std::vector<idx_t>> ids; ///< Inverted lists for indexes
246
+ std::vector<MaybeOwnedVector<uint8_t>> codes; // binary codes, size nlist
247
+ std::vector<MaybeOwnedVector<idx_t>> ids; ///< Inverted lists for indexes
246
248
 
247
249
  ArrayInvertedLists(size_t nlist, size_t code_size);
248
250
 
@@ -13,9 +13,9 @@
13
13
 
14
14
  #include <faiss/invlists/BlockInvertedLists.h>
15
15
 
16
- #ifndef _MSC_VER
16
+ #ifndef _WIN32
17
17
  #include <faiss/invlists/OnDiskInvertedLists.h>
18
- #endif // !_MSC_VER
18
+ #endif // !_WIN32
19
19
 
20
20
  namespace faiss {
21
21
 
@@ -33,7 +33,7 @@ namespace {
33
33
  /// std::vector that deletes its contents
34
34
  struct IOHookTable : std::vector<InvertedListsIOHook*> {
35
35
  IOHookTable() {
36
- #ifndef _MSC_VER
36
+ #ifndef _WIN32
37
37
  push_back(new OnDiskInvertedListsIOHook());
38
38
  #endif
39
39
  push_back(new BlockInvertedListsIOHook());
@@ -134,3 +134,27 @@ PyCallbackIDSelector::~PyCallbackIDSelector() {
134
134
  PyThreadLock gil;
135
135
  Py_DECREF(callback);
136
136
  }
137
+
138
+ /***********************************************************
139
+ * Callbacks for IVF index sharding
140
+ ***********************************************************/
141
+
142
+ PyCallbackShardingFunction::PyCallbackShardingFunction(PyObject* callback)
143
+ : callback(callback) {
144
+ PyThreadLock gil;
145
+ Py_INCREF(callback);
146
+ }
147
+
148
+ int64_t PyCallbackShardingFunction::operator()(int64_t i, int64_t shard_count) {
149
+ PyThreadLock gil;
150
+ PyObject* shard_id = PyObject_CallFunction(callback, "LL", i, shard_count);
151
+ if (shard_id == nullptr) {
152
+ FAISS_THROW_MSG("propagate py error");
153
+ }
154
+ return PyLong_AsLongLong(shard_id);
155
+ }
156
+
157
+ PyCallbackShardingFunction::~PyCallbackShardingFunction() {
158
+ PyThreadLock gil;
159
+ Py_DECREF(callback);
160
+ }
@@ -7,6 +7,7 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <faiss/IVFlib.h>
10
11
  #include <faiss/impl/IDSelector.h>
11
12
  #include <faiss/impl/io.h>
12
13
  #include <faiss/invlists/InvertedLists.h>
@@ -58,3 +59,24 @@ struct PyCallbackIDSelector : faiss::IDSelector {
58
59
 
59
60
  ~PyCallbackIDSelector() override;
60
61
  };
62
+
63
+ /***********************************************************
64
+ * Callbacks for IVF index sharding
65
+ ***********************************************************/
66
+
67
+ struct PyCallbackShardingFunction : faiss::ivflib::ShardingFunction {
68
+ PyObject* callback;
69
+
70
+ explicit PyCallbackShardingFunction(PyObject* callback);
71
+
72
+ int64_t operator()(int64_t i, int64_t shard_count) override;
73
+
74
+ ~PyCallbackShardingFunction() override;
75
+
76
+ PyCallbackShardingFunction(const PyCallbackShardingFunction&) = delete;
77
+ PyCallbackShardingFunction(PyCallbackShardingFunction&&) noexcept = default;
78
+ PyCallbackShardingFunction& operator=(const PyCallbackShardingFunction&) =
79
+ default;
80
+ PyCallbackShardingFunction& operator=(PyCallbackShardingFunction&&) =
81
+ default;
82
+ };
@@ -46,9 +46,11 @@ struct HeapWithBucketsForHamming32<
46
46
  // output distances
47
47
  int* const __restrict bh_val,
48
48
  // output indices, each being within [0, n) range
49
- int64_t* const __restrict bh_ids) {
49
+ int64_t* const __restrict bh_ids,
50
+ // optional id selector for filtering
51
+ const IDSelector* sel = nullptr) {
50
52
  // forward a call to bs_addn with 1 beam
51
- bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
53
+ bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids, sel);
52
54
  }
53
55
 
54
56
  static void bs_addn(
@@ -66,7 +68,9 @@ struct HeapWithBucketsForHamming32<
66
68
  int* const __restrict bh_val,
67
69
  // output indices, each being within [0, n_per_beam * beam_size)
68
70
  // range
69
- int64_t* const __restrict bh_ids) {
71
+ int64_t* const __restrict bh_ids,
72
+ // optional id selector for filtering
73
+ const IDSelector* sel = nullptr) {
70
74
  //
71
75
  using C = CMax<int, int64_t>;
72
76
 
@@ -95,11 +99,22 @@ struct HeapWithBucketsForHamming32<
95
99
  for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
96
100
  for (uint32_t j = 0; j < NBUCKETS_8; j++) {
97
101
  uint32_t hamming_distances[8];
102
+ uint8_t valid_counter = 0;
98
103
  for (size_t j8 = 0; j8 < 8; j8++) {
99
- hamming_distances[j8] = hc.hamming(
100
- binary_vectors +
101
- (j8 + j * 8 + ip + n_per_beam * beam_index) *
102
- code_size);
104
+ const uint32_t idx =
105
+ j8 + j * 8 + ip + n_per_beam * beam_index;
106
+ if (!sel || sel->is_member(idx)) {
107
+ hamming_distances[j8] = hc.hamming(
108
+ binary_vectors + idx * code_size);
109
+ valid_counter++;
110
+ } else {
111
+ hamming_distances[j8] =
112
+ std::numeric_limits<int32_t>::max();
113
+ }
114
+ }
115
+
116
+ if (valid_counter == 8) {
117
+ continue; // Skip if all vectors are filtered out
103
118
  }
104
119
 
105
120
  // loop. Compiler should get rid of unneeded ops
@@ -157,7 +172,8 @@ struct HeapWithBucketsForHamming32<
157
172
  const auto value = min_distances_scalar[j8];
158
173
  const auto index = min_indices_scalar[j8];
159
174
 
160
- if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
175
+ if (value < std::numeric_limits<int32_t>::max() &&
176
+ C::cmp2(bh_val[0], value, bh_ids[0], index)) {
161
177
  heap_replace_top<C>(
162
178
  k, bh_val, bh_ids, value, index);
163
179
  }
@@ -168,11 +184,13 @@ struct HeapWithBucketsForHamming32<
168
184
  // process leftovers
169
185
  for (uint32_t ip = nb; ip < n_per_beam; ip++) {
170
186
  const auto index = ip + n_per_beam * beam_index;
171
- const auto value =
172
- hc.hamming(binary_vectors + (index)*code_size);
187
+ if (!sel || sel->is_member(index)) {
188
+ const auto value =
189
+ hc.hamming(binary_vectors + (index)*code_size);
173
190
 
174
- if (C::cmp(bh_val[0], value)) {
175
- heap_replace_top<C>(k, bh_val, bh_ids, value, index);
191
+ if (C::cmp(bh_val[0], value)) {
192
+ heap_replace_top<C>(k, bh_val, bh_ids, value, index);
193
+ }
176
194
  }
177
195
  }
178
196
  }
@@ -30,6 +30,7 @@
30
30
 
31
31
  #include <faiss/impl/AuxIndexStructures.h>
32
32
  #include <faiss/impl/FaissAssert.h>
33
+ #include <faiss/impl/IDSelector.h>
33
34
  #include <faiss/utils/Heap.h>
34
35
  #include <faiss/utils/approx_topk_hamming/approx_topk_hamming.h>
35
36
  #include <faiss/utils/utils.h>
@@ -62,15 +63,15 @@ void hammings(
62
63
  const uint64_t* __restrict bs2,
63
64
  size_t n1,
64
65
  size_t n2,
65
- size_t nwords,
66
+ size_t nbits,
66
67
  hamdis_t* __restrict dis) {
67
68
  size_t i, j;
68
- n1 *= nwords;
69
- n2 *= nwords;
70
- for (i = 0; i < n1; i += nwords) {
71
- const uint64_t* bs1_ = bs1 + i;
72
- for (j = 0; j < n2; j += nwords)
73
- dis[j] = hamming(bs1_, bs2 + j, nwords);
69
+ const size_t nwords = nbits / 64;
70
+ for (i = 0; i < n1; i++) {
71
+ const uint64_t* __restrict bs1_ = bs1 + i * nwords;
72
+ hamdis_t* __restrict dis_ = dis + i * n2;
73
+ for (j = 0; j < n2; j++)
74
+ dis_[j] = hamming(bs1_, bs2 + j * nwords, nwords);
74
75
  }
75
76
  }
76
77
 
@@ -171,7 +172,8 @@ void hammings_knn_hc(
171
172
  size_t n2,
172
173
  bool order = true,
173
174
  bool init_heap = true,
174
- ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK) {
175
+ ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
176
+ const faiss::IDSelector* sel = nullptr) {
175
177
  size_t k = ha->k;
176
178
  if (init_heap)
177
179
  ha->heapify();
@@ -204,7 +206,7 @@ void hammings_knn_hc(
204
206
  NB, \
205
207
  BD, \
206
208
  HammingComputer>:: \
207
- addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_); \
209
+ addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_, sel); \
208
210
  break;
209
211
 
210
212
  switch (approx_topk_mode) {
@@ -214,6 +216,9 @@ void hammings_knn_hc(
214
216
  HANDLE_APPROX(32, 2)
215
217
  default: {
216
218
  for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
219
+ if (sel && !sel->is_member(j)) {
220
+ continue;
221
+ }
217
222
  dis = hc.hamming(bs2_);
218
223
  if (dis < bh_val_[0]) {
219
224
  faiss::maxheap_replace_top<hamdis_t>(
@@ -238,7 +243,8 @@ void hammings_knn_mc(
238
243
  size_t nb,
239
244
  size_t k,
240
245
  int32_t* __restrict distances,
241
- int64_t* __restrict labels) {
246
+ int64_t* __restrict labels,
247
+ const faiss::IDSelector* sel) {
242
248
  const int nBuckets = bytes_per_code * 8 + 1;
243
249
  std::vector<int> all_counters(na * nBuckets, 0);
244
250
  std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);
@@ -259,7 +265,9 @@ void hammings_knn_mc(
259
265
  #pragma omp parallel for
260
266
  for (int64_t i = 0; i < na; ++i) {
261
267
  for (size_t j = j0; j < j1; ++j) {
262
- cs[i].update_counter(b + j * bytes_per_code, j);
268
+ if (!sel || sel->is_member(j)) {
269
+ cs[i].update_counter(b + j * bytes_per_code, j);
270
+ }
263
271
  }
264
272
  }
265
273
  }
@@ -291,7 +299,8 @@ void hamming_range_search(
291
299
  size_t nb,
292
300
  int radius,
293
301
  size_t code_size,
294
- RangeSearchResult* res) {
302
+ RangeSearchResult* res,
303
+ const faiss::IDSelector* sel) {
295
304
  #pragma omp parallel
296
305
  {
297
306
  RangeSearchPartialResult pres(res);
@@ -303,9 +312,11 @@ void hamming_range_search(
303
312
  RangeQueryResult& qres = pres.new_result(i);
304
313
 
305
314
  for (size_t j = 0; j < nb; j++) {
306
- int dis = hc.hamming(yi);
307
- if (dis < radius) {
308
- qres.add(dis, j);
315
+ if (!sel || sel->is_member(j)) {
316
+ int dis = hc.hamming(yi);
317
+ if (dis < radius) {
318
+ qres.add(dis, j);
319
+ }
309
320
  }
310
321
  yi += code_size;
311
322
  }
@@ -489,10 +500,21 @@ void hammings_knn_hc(
489
500
  size_t nb,
490
501
  size_t ncodes,
491
502
  int order,
492
- ApproxTopK_mode_t approx_topk_mode) {
503
+ ApproxTopK_mode_t approx_topk_mode,
504
+ const faiss::IDSelector* sel) {
493
505
  Run_hammings_knn_hc r;
494
506
  dispatch_HammingComputer(
495
- ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode);
507
+ ncodes,
508
+ r,
509
+ ncodes,
510
+ ha,
511
+ a,
512
+ b,
513
+ nb,
514
+ order,
515
+ true,
516
+ approx_topk_mode,
517
+ sel);
496
518
  }
497
519
 
498
520
  void hammings_knn_mc(
@@ -503,10 +525,11 @@ void hammings_knn_mc(
503
525
  size_t k,
504
526
  size_t ncodes,
505
527
  int32_t* __restrict distances,
506
- int64_t* __restrict labels) {
528
+ int64_t* __restrict labels,
529
+ const faiss::IDSelector* sel) {
507
530
  Run_hammings_knn_mc r;
508
531
  dispatch_HammingComputer(
509
- ncodes, r, ncodes, a, b, na, nb, k, distances, labels);
532
+ ncodes, r, ncodes, a, b, na, nb, k, distances, labels, sel);
510
533
  }
511
534
 
512
535
  void hamming_range_search(
@@ -516,10 +539,11 @@ void hamming_range_search(
516
539
  size_t nb,
517
540
  int radius,
518
541
  size_t code_size,
519
- RangeSearchResult* result) {
542
+ RangeSearchResult* result,
543
+ const faiss::IDSelector* sel) {
520
544
  Run_hamming_range_search r;
521
545
  dispatch_HammingComputer(
522
- code_size, r, a, b, na, nb, radius, code_size, result);
546
+ code_size, r, a, b, na, nb, radius, code_size, result, sel);
523
547
  }
524
548
 
525
549
  /* Count number of matches given a max threshold */
@@ -27,6 +27,7 @@
27
27
 
28
28
  #include <stdint.h>
29
29
 
30
+ #include <faiss/impl/IDSelector.h>
30
31
  #include <faiss/impl/platform_macros.h>
31
32
  #include <faiss/utils/Heap.h>
32
33
 
@@ -135,7 +136,8 @@ void hammings_knn_hc(
135
136
  size_t nb,
136
137
  size_t ncodes,
137
138
  int ordered,
138
- ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK);
139
+ ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
140
+ const faiss::IDSelector* sel = nullptr);
139
141
 
140
142
  /* Legacy alias to hammings_knn_hc. */
141
143
  void hammings_knn(
@@ -166,7 +168,8 @@ void hammings_knn_mc(
166
168
  size_t k,
167
169
  size_t ncodes,
168
170
  int32_t* distances,
169
- int64_t* labels);
171
+ int64_t* labels,
172
+ const faiss::IDSelector* sel = nullptr);
170
173
 
171
174
  /** same as hammings_knn except we are doing a range search with radius */
172
175
  void hamming_range_search(
@@ -176,7 +179,8 @@ void hamming_range_search(
176
179
  size_t nb,
177
180
  int radius,
178
181
  size_t ncodes,
179
- RangeSearchResult* result);
182
+ RangeSearchResult* result,
183
+ const faiss::IDSelector* sel = nullptr);
180
184
 
181
185
  /* Counting the number of matches or of cross-matches (without returning them)
182
186
  For use with function that assume pre-allocated memory */
@@ -11,7 +11,7 @@
11
11
  // AVX512 version
12
12
  // The _mm512_popcnt_epi64 intrinsic is used to accelerate Hamming distance
13
13
  // calculations in HammingComputerDefault and HammingComputer64. This intrinsic
14
- // is not available in the default FAISS avx512 build mode but is only
14
+ // is not available in the default Faiss avx512 build mode but is only
15
15
  // available in the avx512_spr build mode, which targets Intel(R) Sapphire
16
16
  // Rapids.
17
17