faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -21,63 +21,30 @@ namespace faiss {
21
21
 
22
22
  struct IndexHNSW;
23
23
 
24
- struct ReconstructFromNeighbors {
25
- typedef HNSW::storage_idx_t storage_idx_t;
26
-
27
- const IndexHNSW& index;
28
- size_t M; // number of neighbors
29
- size_t k; // number of codebook entries
30
- size_t nsq; // number of subvectors
31
- size_t code_size;
32
- int k_reorder; // nb to reorder. -1 = all
33
-
34
- std::vector<float> codebook; // size nsq * k * (M + 1)
35
-
36
- std::vector<uint8_t> codes; // size ntotal * code_size
37
- size_t ntotal;
38
- size_t d, dsub; // derived values
39
-
40
- explicit ReconstructFromNeighbors(
41
- const IndexHNSW& index,
42
- size_t k = 256,
43
- size_t nsq = 1);
44
-
45
- /// codes must be added in the correct order and the IndexHNSW
46
- /// must be populated and sorted
47
- void add_codes(size_t n, const float* x);
48
-
49
- size_t compute_distances(
50
- size_t n,
51
- const idx_t* shortlist,
52
- const float* query,
53
- float* distances) const;
54
-
55
- /// called by add_codes
56
- void estimate_code(const float* x, storage_idx_t i, uint8_t* code) const;
57
-
58
- /// called by compute_distances
59
- void reconstruct(storage_idx_t i, float* x, float* tmp) const;
60
-
61
- void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float* x) const;
62
-
63
- /// get the M+1 -by-d table for neighbor coordinates for vector i
64
- void get_neighbor_table(storage_idx_t i, float* out) const;
65
- };
66
-
67
24
  /** The HNSW index is a normal random-access index with a HNSW
68
25
  * link structure built on top */
69
26
 
70
27
  struct IndexHNSW : Index {
71
28
  typedef HNSW::storage_idx_t storage_idx_t;
72
29
 
73
- // the link strcuture
30
+ // the link structure
74
31
  HNSW hnsw;
75
32
 
76
33
  // the sequential storage
77
- bool own_fields;
78
- Index* storage;
34
+ bool own_fields = false;
35
+ Index* storage = nullptr;
79
36
 
80
- ReconstructFromNeighbors* reconstruct_from_neighbors;
37
+ // When set to false, level 0 in the knn graph is not initialized.
38
+ // This option is used by GpuIndexCagra::copyTo(IndexHNSWCagra*)
39
+ // as level 0 knn graph is copied over from the index built by
40
+ // GpuIndexCagra.
41
+ bool init_level0 = true;
42
+
43
+ // When set to true, all neighbors in level 0 are filled up
44
+ // to the maximum size allowed (2 * M). This option is used by
45
+ // IndexHHNSWCagra to create a full base layer graph that is
46
+ // used when GpuIndexCagra::copyFrom(IndexHNSWCagra*) is invoked.
47
+ bool keep_max_size_level0 = false;
81
48
 
82
49
  explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
83
50
  explicit IndexHNSW(Index* storage, int M = 32);
@@ -98,6 +65,13 @@ struct IndexHNSW : Index {
98
65
  idx_t* labels,
99
66
  const SearchParameters* params = nullptr) const override;
100
67
 
68
+ void range_search(
69
+ idx_t n,
70
+ const float* x,
71
+ float radius,
72
+ RangeSearchResult* result,
73
+ const SearchParameters* params = nullptr) const override;
74
+
101
75
  void reconstruct(idx_t key, float* recons) const override;
102
76
 
103
77
  void reset() override;
@@ -119,7 +93,8 @@ struct IndexHNSW : Index {
119
93
  float* distances,
120
94
  idx_t* labels,
121
95
  int nprobe = 1,
122
- int search_type = 1) const;
96
+ int search_type = 1,
97
+ const SearchParameters* params = nullptr) const;
123
98
 
124
99
  /// alternative graph building
125
100
  void init_level_0_from_knngraph(int k, const float* D, const idx_t* I);
@@ -134,6 +109,10 @@ struct IndexHNSW : Index {
134
109
  void reorder_links();
135
110
 
136
111
  void link_singletons();
112
+
113
+ void permute_entries(const idx_t* perm);
114
+
115
+ DistanceComputer* get_distance_computer() const override;
137
116
  };
138
117
 
139
118
  /** Flat index topped with with a HNSW structure to access elements
@@ -150,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW {
150
129
  */
151
130
  struct IndexHNSWPQ : IndexHNSW {
152
131
  IndexHNSWPQ();
153
- IndexHNSWPQ(int d, int pq_m, int M);
132
+ IndexHNSWPQ(
133
+ int d,
134
+ int pq_m,
135
+ int M,
136
+ int pq_nbits = 8,
137
+ MetricType metric = METRIC_L2);
154
138
  void train(idx_t n, const float* x) override;
155
139
  };
156
140
 
@@ -184,4 +168,33 @@ struct IndexHNSW2Level : IndexHNSW {
184
168
  const SearchParameters* params = nullptr) const override;
185
169
  };
186
170
 
171
+ struct IndexHNSWCagra : IndexHNSW {
172
+ IndexHNSWCagra();
173
+ IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2);
174
+
175
+ /// When set to true, the index is immutable.
176
+ /// This option is used to copy the knn graph from GpuIndexCagra
177
+ /// to the base level of IndexHNSWCagra without adding upper levels.
178
+ /// Doing so enables to search the HNSW index, but removes the
179
+ /// ability to add vectors.
180
+ bool base_level_only = false;
181
+
182
+ /// When `base_level_only` is set to `True`, the search function
183
+ /// searches only the base level knn graph of the HNSW index.
184
+ /// This parameter selects the entry point by randomly selecting
185
+ /// some points and using the best one.
186
+ int num_base_level_search_entrypoints = 32;
187
+
188
+ void add(idx_t n, const float* x) override;
189
+
190
+ /// entry point for search
191
+ void search(
192
+ idx_t n,
193
+ const float* x,
194
+ idx_t k,
195
+ float* distances,
196
+ idx_t* labels,
197
+ const SearchParameters* params = nullptr) const override;
198
+ };
199
+
187
200
  } // namespace faiss
@@ -9,31 +9,43 @@
9
9
 
10
10
  #include <faiss/IndexIDMap.h>
11
11
 
12
- #include <stdint.h>
13
12
  #include <cinttypes>
13
+ #include <cstdint>
14
14
  #include <cstdio>
15
15
  #include <limits>
16
16
 
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
18
  #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/impl/IDSelector.h>
20
19
  #include <faiss/utils/Heap.h>
21
20
  #include <faiss/utils/WorkerThread.h>
22
21
 
23
22
  namespace faiss {
24
23
 
24
+ namespace {
25
+
26
+ // IndexBinary needs to update the code_size when d is set...
27
+
28
+ void sync_d(Index* index) {}
29
+
30
+ void sync_d(IndexBinary* index) {
31
+ FAISS_THROW_IF_NOT(index->d % 8 == 0);
32
+ index->code_size = index->d / 8;
33
+ }
34
+
35
+ } // anonymous namespace
36
+
25
37
  /*****************************************************
26
38
  * IndexIDMap implementation
27
39
  *******************************************************/
28
40
 
29
41
  template <typename IndexT>
30
- IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
31
- : index(index), own_fields(false) {
42
+ IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index) : index(index) {
32
43
  FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
33
44
  this->is_trained = index->is_trained;
34
45
  this->metric_type = index->metric_type;
35
46
  this->verbose = index->verbose;
36
47
  this->d = index->d;
48
+ sync_d(this);
37
49
  }
38
50
 
39
51
  template <typename IndexT>
@@ -71,6 +83,27 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
71
83
  this->ntotal = index->ntotal;
72
84
  }
73
85
 
86
+ namespace {
87
+
88
+ /// RAII object to reset the IDSelector in the params object
89
+ struct ScopedSelChange {
90
+ SearchParameters* params = nullptr;
91
+ IDSelector* old_sel = nullptr;
92
+
93
+ void set(SearchParameters* params_2, IDSelector* new_sel) {
94
+ this->params = params_2;
95
+ old_sel = params_2->sel;
96
+ params_2->sel = new_sel;
97
+ }
98
+ ~ScopedSelChange() {
99
+ if (params) {
100
+ params->sel = old_sel;
101
+ }
102
+ }
103
+ };
104
+
105
+ } // namespace
106
+
74
107
  template <typename IndexT>
75
108
  void IndexIDMapTemplate<IndexT>::search(
76
109
  idx_t n,
@@ -79,9 +112,26 @@ void IndexIDMapTemplate<IndexT>::search(
79
112
  typename IndexT::distance_t* distances,
80
113
  idx_t* labels,
81
114
  const SearchParameters* params) const {
82
- FAISS_THROW_IF_NOT_MSG(
83
- !params, "search params not supported for this index");
84
- index->search(n, x, k, distances, labels);
115
+ IDSelectorTranslated this_idtrans(this->id_map, nullptr);
116
+ ScopedSelChange sel_change;
117
+
118
+ if (params && params->sel) {
119
+ auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
120
+
121
+ if (!idtrans) {
122
+ /*
123
+ FAISS_THROW_IF_NOT_MSG(
124
+ idtrans,
125
+ "IndexIDMap requires an IDSelectorTranslated on input");
126
+ */
127
+ // then make an idtrans and force it into the SearchParameters
128
+ // (hence the const_cast)
129
+ auto params_non_const = const_cast<SearchParameters*>(params);
130
+ this_idtrans.sel = params->sel;
131
+ sel_change.set(params_non_const, &this_idtrans);
132
+ }
133
+ }
134
+ index->search(n, x, k, distances, labels, params);
85
135
  idx_t* li = labels;
86
136
  #pragma omp parallel for
87
137
  for (idx_t i = 0; i < n * k; i++) {
@@ -96,9 +146,16 @@ void IndexIDMapTemplate<IndexT>::range_search(
96
146
  typename IndexT::distance_t radius,
97
147
  RangeSearchResult* result,
98
148
  const SearchParameters* params) const {
99
- FAISS_THROW_IF_NOT_MSG(
100
- !params, "search params not supported for this index");
101
- index->range_search(n, x, radius, result);
149
+ if (params) {
150
+ SearchParameters internal_search_parameters;
151
+ IDSelectorTranslated id_selector_translated(id_map, params->sel);
152
+ internal_search_parameters.sel = &id_selector_translated;
153
+
154
+ index->range_search(n, x, radius, result, &internal_search_parameters);
155
+ } else {
156
+ index->range_search(n, x, radius, result);
157
+ }
158
+
102
159
  #pragma omp parallel for
103
160
  for (idx_t i = 0; i < result->lims[result->nq]; i++) {
104
161
  result->labels[i] = result->labels[i] < 0 ? result->labels[i]
@@ -106,26 +163,10 @@ void IndexIDMapTemplate<IndexT>::range_search(
106
163
  }
107
164
  }
108
165
 
109
- namespace {
110
-
111
- struct IDTranslatedSelector : IDSelector {
112
- const std::vector<int64_t>& id_map;
113
- const IDSelector& sel;
114
- IDTranslatedSelector(
115
- const std::vector<int64_t>& id_map,
116
- const IDSelector& sel)
117
- : id_map(id_map), sel(sel) {}
118
- bool is_member(idx_t id) const override {
119
- return sel.is_member(id_map[id]);
120
- }
121
- };
122
-
123
- } // namespace
124
-
125
166
  template <typename IndexT>
126
167
  size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127
168
  // remove in sub-index first
128
- IDTranslatedSelector sel2(id_map, sel);
169
+ IDSelectorTranslated sel2(id_map, &sel);
129
170
  size_t nremove = index->remove_ids(sel2);
130
171
 
131
172
  int64_t j = 0;
@@ -232,7 +273,7 @@ void IndexIDMap2Template<IndexT>::reconstruct(
232
273
  typename IndexT::component_t* recons) const {
233
274
  try {
234
275
  this->index->reconstruct(rev_map.at(key), recons);
235
- } catch (const std::out_of_range& e) {
276
+ } catch (const std::out_of_range&) {
236
277
  FAISS_THROW_FMT("key %" PRId64 " not found", key);
237
278
  }
238
279
  }
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/Index.h>
11
11
  #include <faiss/IndexBinary.h>
12
+ #include <faiss/impl/IDSelector.h>
12
13
 
13
14
  #include <unordered_map>
14
15
  #include <vector>
@@ -21,8 +22,8 @@ struct IndexIDMapTemplate : IndexT {
21
22
  using component_t = typename IndexT::component_t;
22
23
  using distance_t = typename IndexT::distance_t;
23
24
 
24
- IndexT* index; ///! the sub-index
25
- bool own_fields; ///! whether pointers are deleted in destructo
25
+ IndexT* index = nullptr; ///! the sub-index
26
+ bool own_fields = false; ///! whether pointers are deleted in destructo
26
27
  std::vector<idx_t> id_map;
27
28
 
28
29
  explicit IndexIDMapTemplate(IndexT* index);
@@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
102
103
  using IndexIDMap2 = IndexIDMap2Template<Index>;
103
104
  using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
104
105
 
106
+ // IDSelector that translates the ids using an IDMap
107
+ struct IDSelectorTranslated : IDSelector {
108
+ const std::vector<int64_t>& id_map;
109
+ const IDSelector* sel;
110
+
111
+ IDSelectorTranslated(
112
+ const std::vector<int64_t>& id_map,
113
+ const IDSelector* sel)
114
+ : id_map(id_map), sel(sel) {}
115
+
116
+ IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel)
117
+ : id_map(index_idmap.id_map), sel(sel) {}
118
+
119
+ IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel)
120
+ : id_map(index_idmap.id_map), sel(sel) {}
121
+
122
+ bool is_member(idx_t id) const override {
123
+ return sel->is_member(id_map[id]);
124
+ }
125
+ };
126
+
105
127
  } // namespace faiss