faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -5,17 +5,14 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <faiss/IndexPQ.h>
12
11
  #include <faiss/impl/ProductQuantizer.h>
13
12
  #include <faiss/utils/AlignedTable.h>
14
13
 
15
-
16
14
  namespace faiss {
17
15
 
18
-
19
16
  /** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
20
17
  *
21
18
  * The codes are not stored sequentially but grouped in blocks of size bbs.
@@ -28,7 +25,7 @@ namespace faiss {
28
25
  * 15: no qbs with reservoir accumulator
29
26
  */
30
27
 
31
- struct IndexPQFastScan: Index {
28
+ struct IndexPQFastScan : Index {
32
29
  ProductQuantizer pq;
33
30
 
34
31
  // implementation to select
@@ -37,8 +34,8 @@ struct IndexPQFastScan: Index {
37
34
  int skip = 0;
38
35
 
39
36
  // size of the kernel
40
- int bbs; // set at build time
41
- int qbs = 0; // query block size 0 = use default
37
+ int bbs; // set at build time
38
+ int qbs = 0; // query block size 0 = use default
42
39
 
43
40
  // packed version of the codes
44
41
  size_t ntotal2;
@@ -47,22 +44,23 @@ struct IndexPQFastScan: Index {
47
44
  AlignedTable<uint8_t> codes;
48
45
 
49
46
  // this is for testing purposes only (set when initialized by IndexPQ)
50
- const uint8_t *orig_codes = nullptr;
47
+ const uint8_t* orig_codes = nullptr;
51
48
 
52
49
  IndexPQFastScan(
53
- int d, size_t M, size_t nbits,
54
- MetricType metric = METRIC_L2,
55
- int bbs = 32
56
- );
50
+ int d,
51
+ size_t M,
52
+ size_t nbits,
53
+ MetricType metric = METRIC_L2,
54
+ int bbs = 32);
57
55
 
58
56
  IndexPQFastScan();
59
57
 
60
58
  /// build from an existing IndexPQ
61
- explicit IndexPQFastScan(const IndexPQ & orig, int bbs = 32);
59
+ explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
62
60
 
63
- void train (idx_t n, const float *x) override;
64
- void add (idx_t n, const float *x) override;
65
- void reset() override ;
61
+ void train(idx_t n, const float* x) override;
62
+ void add(idx_t n, const float* x) override;
63
+ void reset() override;
66
64
  void search(
67
65
  idx_t n,
68
66
  const float* x,
@@ -72,35 +70,51 @@ struct IndexPQFastScan: Index {
72
70
 
73
71
  // called by search function
74
72
  void compute_quantized_LUT(
75
- idx_t n, const float* x,
76
- uint8_t *lut, float *normalizers) const ;
73
+ idx_t n,
74
+ const float* x,
75
+ uint8_t* lut,
76
+ float* normalizers) const;
77
77
 
78
- template<bool is_max>
78
+ template <bool is_max>
79
79
  void search_dispatch_implem(
80
- idx_t n, const float* x, idx_t k,
81
- float* distances, idx_t* labels) const;
80
+ idx_t n,
81
+ const float* x,
82
+ idx_t k,
83
+ float* distances,
84
+ idx_t* labels) const;
82
85
 
83
- template<class C>
86
+ template <class C>
84
87
  void search_implem_2(
85
- idx_t n, const float* x, idx_t k,
86
- float* distances, idx_t* labels) const;
87
-
88
+ idx_t n,
89
+ const float* x,
90
+ idx_t k,
91
+ float* distances,
92
+ idx_t* labels) const;
88
93
 
89
- template<class C>
94
+ template <class C>
90
95
  void search_implem_12(
91
- idx_t n, const float* x, idx_t k,
92
- float* distances, idx_t* labels, int impl) const;
96
+ idx_t n,
97
+ const float* x,
98
+ idx_t k,
99
+ float* distances,
100
+ idx_t* labels,
101
+ int impl) const;
93
102
 
94
- template<class C>
103
+ template <class C>
95
104
  void search_implem_14(
96
- idx_t n, const float* x, idx_t k,
97
- float* distances, idx_t* labels, int impl) const;
98
-
105
+ idx_t n,
106
+ const float* x,
107
+ idx_t k,
108
+ float* distances,
109
+ idx_t* labels,
110
+ int impl) const;
99
111
  };
100
112
 
101
113
  struct FastScanStats {
102
114
  uint64_t t0, t1, t2, t3;
103
- FastScanStats() {reset();}
115
+ FastScanStats() {
116
+ reset();
117
+ }
104
118
  void reset() {
105
119
  memset(this, 0, sizeof(*this));
106
120
  }
@@ -9,13 +9,13 @@
9
9
 
10
10
  #include <faiss/IndexPreTransform.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <cmath>
13
+ #include <cstdio>
14
14
  #include <cstring>
15
15
  #include <memory>
16
16
 
17
- #include <faiss/impl/FaissAssert.h>
18
17
  #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
19
 
20
20
  namespace faiss {
21
21
 
@@ -23,44 +23,29 @@ namespace faiss {
23
23
  * IndexPreTransform
24
24
  *********************************************/
25
25
 
26
- IndexPreTransform::IndexPreTransform ():
27
- index(nullptr), own_fields (false)
28
- {
29
- }
30
-
26
+ IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {}
31
27
 
32
- IndexPreTransform::IndexPreTransform (
33
- Index * index):
34
- Index (index->d, index->metric_type),
35
- index (index), own_fields (false)
36
- {
28
+ IndexPreTransform::IndexPreTransform(Index* index)
29
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
37
30
  is_trained = index->is_trained;
38
31
  ntotal = index->ntotal;
39
32
  }
40
33
 
41
-
42
- IndexPreTransform::IndexPreTransform (
43
- VectorTransform * ltrans,
44
- Index * index):
45
- Index (index->d, index->metric_type),
46
- index (index), own_fields (false)
47
- {
34
+ IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
35
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
48
36
  is_trained = index->is_trained;
49
37
  ntotal = index->ntotal;
50
- prepend_transform (ltrans);
38
+ prepend_transform(ltrans);
51
39
  }
52
40
 
53
- void IndexPreTransform::prepend_transform (VectorTransform *ltrans)
54
- {
55
- FAISS_THROW_IF_NOT (ltrans->d_out == d);
41
+ void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
42
+ FAISS_THROW_IF_NOT(ltrans->d_out == d);
56
43
  is_trained = is_trained && ltrans->is_trained;
57
- chain.insert (chain.begin(), ltrans);
44
+ chain.insert(chain.begin(), ltrans);
58
45
  d = ltrans->d_in;
59
46
  }
60
47
 
61
-
62
- IndexPreTransform::~IndexPreTransform ()
63
- {
48
+ IndexPreTransform::~IndexPreTransform() {
64
49
  if (own_fields) {
65
50
  for (int i = 0; i < chain.size(); i++)
66
51
  delete chain[i];
@@ -68,11 +53,7 @@ IndexPreTransform::~IndexPreTransform ()
68
53
  }
69
54
  }
70
55
 
71
-
72
-
73
-
74
- void IndexPreTransform::train (idx_t n, const float *x)
75
- {
56
+ void IndexPreTransform::train(idx_t n, const float* x) {
76
57
  int last_untrained = 0;
77
58
  if (!index->is_trained) {
78
59
  last_untrained = chain.size();
@@ -84,7 +65,7 @@ void IndexPreTransform::train (idx_t n, const float *x)
84
65
  }
85
66
  }
86
67
  }
87
- const float *prev_x = x;
68
+ const float* prev_x = x;
88
69
  ScopeDeleter<float> del;
89
70
 
90
71
  if (verbose) {
@@ -93,34 +74,35 @@ void IndexPreTransform::train (idx_t n, const float *x)
93
74
  }
94
75
 
95
76
  for (int i = 0; i <= last_untrained; i++) {
96
-
97
77
  if (i < chain.size()) {
98
- VectorTransform *ltrans = chain [i];
78
+ VectorTransform* ltrans = chain[i];
99
79
  if (!ltrans->is_trained) {
100
80
  if (verbose) {
101
81
  printf(" Training chain component %d/%zd\n",
102
- i, chain.size());
103
- if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
82
+ i,
83
+ chain.size());
84
+ if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
104
85
  opqm->verbose = true;
105
86
  }
106
87
  }
107
- ltrans->train (n, prev_x);
88
+ ltrans->train(n, prev_x);
108
89
  }
109
90
  } else {
110
91
  if (verbose) {
111
92
  printf(" Training sub-index\n");
112
93
  }
113
- index->train (n, prev_x);
94
+ index->train(n, prev_x);
114
95
  }
115
- if (i == last_untrained) break;
96
+ if (i == last_untrained)
97
+ break;
116
98
  if (verbose) {
117
- printf(" Applying transform %d/%zd\n",
118
- i, chain.size());
99
+ printf(" Applying transform %d/%zd\n", i, chain.size());
119
100
  }
120
101
 
121
- float * xt = chain[i]->apply (n, prev_x);
102
+ float* xt = chain[i]->apply(n, prev_x);
122
103
 
123
- if (prev_x != x) delete [] prev_x;
104
+ if (prev_x != x)
105
+ delete[] prev_x;
124
106
  prev_x = xt;
125
107
  del.set(xt);
126
108
  }
@@ -128,200 +110,190 @@ void IndexPreTransform::train (idx_t n, const float *x)
128
110
  is_trained = true;
129
111
  }
130
112
 
131
-
132
- const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
133
- {
134
- const float *prev_x = x;
113
+ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
114
+ const float* prev_x = x;
135
115
  ScopeDeleter<float> del;
136
116
 
137
117
  for (int i = 0; i < chain.size(); i++) {
138
- float * xt = chain[i]->apply (n, prev_x);
139
- ScopeDeleter<float> del2 (xt);
140
- del2.swap (del);
118
+ float* xt = chain[i]->apply(n, prev_x);
119
+ ScopeDeleter<float> del2(xt);
120
+ del2.swap(del);
141
121
  prev_x = xt;
142
122
  }
143
- del.release ();
123
+ del.release();
144
124
  return prev_x;
145
125
  }
146
126
 
147
- void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const
148
- {
127
+ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
128
+ const {
149
129
  const float* next_x = xt;
150
130
  ScopeDeleter<float> del;
151
131
 
152
132
  for (int i = chain.size() - 1; i >= 0; i--) {
153
- float* prev_x = (i == 0) ? x : new float [n * chain[i]->d_in];
154
- ScopeDeleter<float> del2 ((prev_x == x) ? nullptr : prev_x);
155
- chain [i]->reverse_transform (n, next_x, prev_x);
156
- del2.swap (del);
133
+ float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
134
+ ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
135
+ chain[i]->reverse_transform(n, next_x, prev_x);
136
+ del2.swap(del);
157
137
  next_x = prev_x;
158
138
  }
159
139
  }
160
140
 
161
- void IndexPreTransform::add (idx_t n, const float *x)
162
- {
163
- FAISS_THROW_IF_NOT (is_trained);
164
- const float *xt = apply_chain (n, x);
141
+ void IndexPreTransform::add(idx_t n, const float* x) {
142
+ FAISS_THROW_IF_NOT(is_trained);
143
+ const float* xt = apply_chain(n, x);
165
144
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
166
- index->add (n, xt);
145
+ index->add(n, xt);
167
146
  ntotal = index->ntotal;
168
147
  }
169
148
 
170
- void IndexPreTransform::add_with_ids (idx_t n, const float * x,
171
- const idx_t *xids)
172
- {
173
- FAISS_THROW_IF_NOT (is_trained);
174
- const float *xt = apply_chain (n, x);
149
+ void IndexPreTransform::add_with_ids(
150
+ idx_t n,
151
+ const float* x,
152
+ const idx_t* xids) {
153
+ FAISS_THROW_IF_NOT(is_trained);
154
+ const float* xt = apply_chain(n, x);
175
155
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
176
- index->add_with_ids (n, xt, xids);
156
+ index->add_with_ids(n, xt, xids);
177
157
  ntotal = index->ntotal;
178
158
  }
179
159
 
160
+ void IndexPreTransform::search(
161
+ idx_t n,
162
+ const float* x,
163
+ idx_t k,
164
+ float* distances,
165
+ idx_t* labels) const {
166
+ FAISS_THROW_IF_NOT(k > 0);
180
167
 
181
-
182
-
183
- void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
184
- float *distances, idx_t *labels) const
185
- {
186
- FAISS_THROW_IF_NOT (is_trained);
187
- const float *xt = apply_chain (n, x);
168
+ FAISS_THROW_IF_NOT(is_trained);
169
+ const float* xt = apply_chain(n, x);
188
170
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
189
- index->search (n, xt, k, distances, labels);
171
+ index->search(n, xt, k, distances, labels);
190
172
  }
191
173
 
192
- void IndexPreTransform::range_search (idx_t n, const float* x, float radius,
193
- RangeSearchResult* result) const
194
- {
195
- FAISS_THROW_IF_NOT (is_trained);
196
- const float *xt = apply_chain (n, x);
174
+ void IndexPreTransform::range_search(
175
+ idx_t n,
176
+ const float* x,
177
+ float radius,
178
+ RangeSearchResult* result) const {
179
+ FAISS_THROW_IF_NOT(is_trained);
180
+ const float* xt = apply_chain(n, x);
197
181
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
198
- index->range_search (n, xt, radius, result);
182
+ index->range_search(n, xt, radius, result);
199
183
  }
200
184
 
201
-
202
-
203
- void IndexPreTransform::reset () {
185
+ void IndexPreTransform::reset() {
204
186
  index->reset();
205
187
  ntotal = 0;
206
188
  }
207
189
 
208
- size_t IndexPreTransform::remove_ids (const IDSelector & sel) {
209
- size_t nremove = index->remove_ids (sel);
190
+ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
191
+ size_t nremove = index->remove_ids(sel);
210
192
  ntotal = index->ntotal;
211
193
  return nremove;
212
194
  }
213
195
 
214
-
215
- void IndexPreTransform::reconstruct (idx_t key, float * recons) const
216
- {
217
- float *x = chain.empty() ? recons : new float [index->d];
218
- ScopeDeleter<float> del (recons == x ? nullptr : x);
196
+ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
197
+ float* x = chain.empty() ? recons : new float[index->d];
198
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
219
199
  // Initial reconstruction
220
- index->reconstruct (key, x);
200
+ index->reconstruct(key, x);
221
201
 
222
202
  // Revert transformations from last to first
223
- reverse_chain (1, x, recons);
203
+ reverse_chain(1, x, recons);
224
204
  }
225
205
 
226
-
227
- void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
228
- {
229
- float *x = chain.empty() ? recons : new float [ni * index->d];
230
- ScopeDeleter<float> del (recons == x ? nullptr : x);
206
+ void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
207
+ float* x = chain.empty() ? recons : new float[ni * index->d];
208
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
231
209
  // Initial reconstruction
232
- index->reconstruct_n (i0, ni, x);
210
+ index->reconstruct_n(i0, ni, x);
233
211
 
234
212
  // Revert transformations from last to first
235
- reverse_chain (ni, x, recons);
213
+ reverse_chain(ni, x, recons);
236
214
  }
237
215
 
216
+ void IndexPreTransform::search_and_reconstruct(
217
+ idx_t n,
218
+ const float* x,
219
+ idx_t k,
220
+ float* distances,
221
+ idx_t* labels,
222
+ float* recons) const {
223
+ FAISS_THROW_IF_NOT(k > 0);
238
224
 
239
- void IndexPreTransform::search_and_reconstruct (
240
- idx_t n, const float *x, idx_t k,
241
- float *distances, idx_t *labels, float* recons) const
242
- {
243
- FAISS_THROW_IF_NOT (is_trained);
225
+ FAISS_THROW_IF_NOT(is_trained);
244
226
 
245
- const float* xt = apply_chain (n, x);
246
- ScopeDeleter<float> del ((xt == x) ? nullptr : xt);
227
+ const float* xt = apply_chain(n, x);
228
+ ScopeDeleter<float> del((xt == x) ? nullptr : xt);
247
229
 
248
- float* recons_temp = chain.empty() ? recons : new float [n * k * index->d];
249
- ScopeDeleter<float> del2 ((recons_temp == recons) ? nullptr : recons_temp);
250
- index->search_and_reconstruct (n, xt, k, distances, labels, recons_temp);
230
+ float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
231
+ ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
232
+ index->search_and_reconstruct(n, xt, k, distances, labels, recons_temp);
251
233
 
252
234
  // Revert transformations from last to first
253
- reverse_chain (n * k, recons_temp, recons);
235
+ reverse_chain(n * k, recons_temp, recons);
254
236
  }
255
237
 
256
- size_t IndexPreTransform::sa_code_size () const
257
- {
258
- return index->sa_code_size ();
238
+ size_t IndexPreTransform::sa_code_size() const {
239
+ return index->sa_code_size();
259
240
  }
260
241
 
261
- void IndexPreTransform::sa_encode (idx_t n, const float *x,
262
- uint8_t *bytes) const
263
- {
242
+ void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
243
+ const {
264
244
  if (chain.empty()) {
265
- index->sa_encode (n, x, bytes);
245
+ index->sa_encode(n, x, bytes);
266
246
  } else {
267
- const float *xt = apply_chain (n, x);
247
+ const float* xt = apply_chain(n, x);
268
248
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
269
- index->sa_encode (n, xt, bytes);
249
+ index->sa_encode(n, xt, bytes);
270
250
  }
271
251
  }
272
252
 
273
- void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
274
- float *x) const
275
- {
253
+ void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
254
+ const {
276
255
  if (chain.empty()) {
277
- index->sa_decode (n, bytes, x);
256
+ index->sa_decode(n, bytes, x);
278
257
  } else {
279
- std::unique_ptr<float []> x1 (new float [index->d * n]);
280
- index->sa_decode (n, bytes, x1.get());
258
+ std::unique_ptr<float[]> x1(new float[index->d * n]);
259
+ index->sa_decode(n, bytes, x1.get());
281
260
  // Revert transformations from last to first
282
- reverse_chain (n, x1.get(), x);
261
+ reverse_chain(n, x1.get(), x);
283
262
  }
284
263
  }
285
264
 
286
265
  namespace {
287
266
 
288
- struct PreTransformDistanceComputer: DistanceComputer {
289
- const IndexPreTransform *index;
267
+ struct PreTransformDistanceComputer : DistanceComputer {
268
+ const IndexPreTransform* index;
290
269
  std::unique_ptr<DistanceComputer> sub_dc;
291
- std::unique_ptr<const float []> query;
270
+ std::unique_ptr<const float[]> query;
292
271
 
293
- explicit PreTransformDistanceComputer(const IndexPreTransform *index):
294
- index(index),
295
- sub_dc(index->index->get_distance_computer())
296
- {}
272
+ explicit PreTransformDistanceComputer(const IndexPreTransform* index)
273
+ : index(index), sub_dc(index->index->get_distance_computer()) {}
297
274
 
298
- void set_query(const float *x) override {
299
- const float *xt = index->apply_chain (1, x);
275
+ void set_query(const float* x) override {
276
+ const float* xt = index->apply_chain(1, x);
300
277
  if (xt == x) {
301
- sub_dc->set_query (x);
278
+ sub_dc->set_query(x);
302
279
  } else {
303
280
  query.reset(xt);
304
- sub_dc->set_query (xt);
281
+ sub_dc->set_query(xt);
305
282
  }
306
283
  }
307
284
 
308
- float symmetric_dis(idx_t i, idx_t j) override
309
- {
285
+ float symmetric_dis(idx_t i, idx_t j) override {
310
286
  return sub_dc->symmetric_dis(i, j);
311
287
  }
312
288
 
313
- float operator () (idx_t i) override
314
- {
289
+ float operator()(idx_t i) override {
315
290
  return (*sub_dc)(i);
316
291
  }
317
-
318
292
  };
319
293
 
320
-
321
294
  } // anonymous namespace
322
295
 
323
-
324
- DistanceComputer * IndexPreTransform::get_distance_computer() const {
296
+ DistanceComputer* IndexPreTransform::get_distance_computer() const {
325
297
  if (chain.empty()) {
326
298
  return index->get_distance_computer();
327
299
  } else {
@@ -329,6 +301,4 @@ DistanceComputer * IndexPreTransform::get_distance_computer() const {
329
301
  }
330
302
  }
331
303
 
332
-
333
-
334
304
  } // namespace faiss