faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -5,17 +5,14 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <faiss/IndexPQ.h>
12
11
  #include <faiss/impl/ProductQuantizer.h>
13
12
  #include <faiss/utils/AlignedTable.h>
14
13
 
15
-
16
14
  namespace faiss {
17
15
 
18
-
19
16
  /** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
20
17
  *
21
18
  * The codes are not stored sequentially but grouped in blocks of size bbs.
@@ -28,7 +25,7 @@ namespace faiss {
28
25
  * 15: no qbs with reservoir accumulator
29
26
  */
30
27
 
31
- struct IndexPQFastScan: Index {
28
+ struct IndexPQFastScan : Index {
32
29
  ProductQuantizer pq;
33
30
 
34
31
  // implementation to select
@@ -37,8 +34,8 @@ struct IndexPQFastScan: Index {
37
34
  int skip = 0;
38
35
 
39
36
  // size of the kernel
40
- int bbs; // set at build time
41
- int qbs = 0; // query block size 0 = use default
37
+ int bbs; // set at build time
38
+ int qbs = 0; // query block size 0 = use default
42
39
 
43
40
  // packed version of the codes
44
41
  size_t ntotal2;
@@ -47,22 +44,23 @@ struct IndexPQFastScan: Index {
47
44
  AlignedTable<uint8_t> codes;
48
45
 
49
46
  // this is for testing purposes only (set when initialized by IndexPQ)
50
- const uint8_t *orig_codes = nullptr;
47
+ const uint8_t* orig_codes = nullptr;
51
48
 
52
49
  IndexPQFastScan(
53
- int d, size_t M, size_t nbits,
54
- MetricType metric = METRIC_L2,
55
- int bbs = 32
56
- );
50
+ int d,
51
+ size_t M,
52
+ size_t nbits,
53
+ MetricType metric = METRIC_L2,
54
+ int bbs = 32);
57
55
 
58
56
  IndexPQFastScan();
59
57
 
60
58
  /// build from an existing IndexPQ
61
- explicit IndexPQFastScan(const IndexPQ & orig, int bbs = 32);
59
+ explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
62
60
 
63
- void train (idx_t n, const float *x) override;
64
- void add (idx_t n, const float *x) override;
65
- void reset() override ;
61
+ void train(idx_t n, const float* x) override;
62
+ void add(idx_t n, const float* x) override;
63
+ void reset() override;
66
64
  void search(
67
65
  idx_t n,
68
66
  const float* x,
@@ -72,35 +70,51 @@ struct IndexPQFastScan: Index {
72
70
 
73
71
  // called by search function
74
72
  void compute_quantized_LUT(
75
- idx_t n, const float* x,
76
- uint8_t *lut, float *normalizers) const ;
73
+ idx_t n,
74
+ const float* x,
75
+ uint8_t* lut,
76
+ float* normalizers) const;
77
77
 
78
- template<bool is_max>
78
+ template <bool is_max>
79
79
  void search_dispatch_implem(
80
- idx_t n, const float* x, idx_t k,
81
- float* distances, idx_t* labels) const;
80
+ idx_t n,
81
+ const float* x,
82
+ idx_t k,
83
+ float* distances,
84
+ idx_t* labels) const;
82
85
 
83
- template<class C>
86
+ template <class C>
84
87
  void search_implem_2(
85
- idx_t n, const float* x, idx_t k,
86
- float* distances, idx_t* labels) const;
87
-
88
+ idx_t n,
89
+ const float* x,
90
+ idx_t k,
91
+ float* distances,
92
+ idx_t* labels) const;
88
93
 
89
- template<class C>
94
+ template <class C>
90
95
  void search_implem_12(
91
- idx_t n, const float* x, idx_t k,
92
- float* distances, idx_t* labels, int impl) const;
96
+ idx_t n,
97
+ const float* x,
98
+ idx_t k,
99
+ float* distances,
100
+ idx_t* labels,
101
+ int impl) const;
93
102
 
94
- template<class C>
103
+ template <class C>
95
104
  void search_implem_14(
96
- idx_t n, const float* x, idx_t k,
97
- float* distances, idx_t* labels, int impl) const;
98
-
105
+ idx_t n,
106
+ const float* x,
107
+ idx_t k,
108
+ float* distances,
109
+ idx_t* labels,
110
+ int impl) const;
99
111
  };
100
112
 
101
113
  struct FastScanStats {
102
114
  uint64_t t0, t1, t2, t3;
103
- FastScanStats() {reset();}
115
+ FastScanStats() {
116
+ reset();
117
+ }
104
118
  void reset() {
105
119
  memset(this, 0, sizeof(*this));
106
120
  }
@@ -9,13 +9,13 @@
9
9
 
10
10
  #include <faiss/IndexPreTransform.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <cmath>
13
+ #include <cstdio>
14
14
  #include <cstring>
15
15
  #include <memory>
16
16
 
17
- #include <faiss/impl/FaissAssert.h>
18
17
  #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
19
 
20
20
  namespace faiss {
21
21
 
@@ -23,44 +23,29 @@ namespace faiss {
23
23
  * IndexPreTransform
24
24
  *********************************************/
25
25
 
26
- IndexPreTransform::IndexPreTransform ():
27
- index(nullptr), own_fields (false)
28
- {
29
- }
30
-
26
+ IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {}
31
27
 
32
- IndexPreTransform::IndexPreTransform (
33
- Index * index):
34
- Index (index->d, index->metric_type),
35
- index (index), own_fields (false)
36
- {
28
+ IndexPreTransform::IndexPreTransform(Index* index)
29
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
37
30
  is_trained = index->is_trained;
38
31
  ntotal = index->ntotal;
39
32
  }
40
33
 
41
-
42
- IndexPreTransform::IndexPreTransform (
43
- VectorTransform * ltrans,
44
- Index * index):
45
- Index (index->d, index->metric_type),
46
- index (index), own_fields (false)
47
- {
34
+ IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
35
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
48
36
  is_trained = index->is_trained;
49
37
  ntotal = index->ntotal;
50
- prepend_transform (ltrans);
38
+ prepend_transform(ltrans);
51
39
  }
52
40
 
53
- void IndexPreTransform::prepend_transform (VectorTransform *ltrans)
54
- {
55
- FAISS_THROW_IF_NOT (ltrans->d_out == d);
41
+ void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
42
+ FAISS_THROW_IF_NOT(ltrans->d_out == d);
56
43
  is_trained = is_trained && ltrans->is_trained;
57
- chain.insert (chain.begin(), ltrans);
44
+ chain.insert(chain.begin(), ltrans);
58
45
  d = ltrans->d_in;
59
46
  }
60
47
 
61
-
62
- IndexPreTransform::~IndexPreTransform ()
63
- {
48
+ IndexPreTransform::~IndexPreTransform() {
64
49
  if (own_fields) {
65
50
  for (int i = 0; i < chain.size(); i++)
66
51
  delete chain[i];
@@ -68,11 +53,7 @@ IndexPreTransform::~IndexPreTransform ()
68
53
  }
69
54
  }
70
55
 
71
-
72
-
73
-
74
- void IndexPreTransform::train (idx_t n, const float *x)
75
- {
56
+ void IndexPreTransform::train(idx_t n, const float* x) {
76
57
  int last_untrained = 0;
77
58
  if (!index->is_trained) {
78
59
  last_untrained = chain.size();
@@ -84,7 +65,7 @@ void IndexPreTransform::train (idx_t n, const float *x)
84
65
  }
85
66
  }
86
67
  }
87
- const float *prev_x = x;
68
+ const float* prev_x = x;
88
69
  ScopeDeleter<float> del;
89
70
 
90
71
  if (verbose) {
@@ -93,34 +74,35 @@ void IndexPreTransform::train (idx_t n, const float *x)
93
74
  }
94
75
 
95
76
  for (int i = 0; i <= last_untrained; i++) {
96
-
97
77
  if (i < chain.size()) {
98
- VectorTransform *ltrans = chain [i];
78
+ VectorTransform* ltrans = chain[i];
99
79
  if (!ltrans->is_trained) {
100
80
  if (verbose) {
101
81
  printf(" Training chain component %d/%zd\n",
102
- i, chain.size());
103
- if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
82
+ i,
83
+ chain.size());
84
+ if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
104
85
  opqm->verbose = true;
105
86
  }
106
87
  }
107
- ltrans->train (n, prev_x);
88
+ ltrans->train(n, prev_x);
108
89
  }
109
90
  } else {
110
91
  if (verbose) {
111
92
  printf(" Training sub-index\n");
112
93
  }
113
- index->train (n, prev_x);
94
+ index->train(n, prev_x);
114
95
  }
115
- if (i == last_untrained) break;
96
+ if (i == last_untrained)
97
+ break;
116
98
  if (verbose) {
117
- printf(" Applying transform %d/%zd\n",
118
- i, chain.size());
99
+ printf(" Applying transform %d/%zd\n", i, chain.size());
119
100
  }
120
101
 
121
- float * xt = chain[i]->apply (n, prev_x);
102
+ float* xt = chain[i]->apply(n, prev_x);
122
103
 
123
- if (prev_x != x) delete [] prev_x;
104
+ if (prev_x != x)
105
+ delete[] prev_x;
124
106
  prev_x = xt;
125
107
  del.set(xt);
126
108
  }
@@ -128,200 +110,190 @@ void IndexPreTransform::train (idx_t n, const float *x)
128
110
  is_trained = true;
129
111
  }
130
112
 
131
-
132
- const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
133
- {
134
- const float *prev_x = x;
113
+ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
114
+ const float* prev_x = x;
135
115
  ScopeDeleter<float> del;
136
116
 
137
117
  for (int i = 0; i < chain.size(); i++) {
138
- float * xt = chain[i]->apply (n, prev_x);
139
- ScopeDeleter<float> del2 (xt);
140
- del2.swap (del);
118
+ float* xt = chain[i]->apply(n, prev_x);
119
+ ScopeDeleter<float> del2(xt);
120
+ del2.swap(del);
141
121
  prev_x = xt;
142
122
  }
143
- del.release ();
123
+ del.release();
144
124
  return prev_x;
145
125
  }
146
126
 
147
- void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const
148
- {
127
+ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
128
+ const {
149
129
  const float* next_x = xt;
150
130
  ScopeDeleter<float> del;
151
131
 
152
132
  for (int i = chain.size() - 1; i >= 0; i--) {
153
- float* prev_x = (i == 0) ? x : new float [n * chain[i]->d_in];
154
- ScopeDeleter<float> del2 ((prev_x == x) ? nullptr : prev_x);
155
- chain [i]->reverse_transform (n, next_x, prev_x);
156
- del2.swap (del);
133
+ float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
134
+ ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
135
+ chain[i]->reverse_transform(n, next_x, prev_x);
136
+ del2.swap(del);
157
137
  next_x = prev_x;
158
138
  }
159
139
  }
160
140
 
161
- void IndexPreTransform::add (idx_t n, const float *x)
162
- {
163
- FAISS_THROW_IF_NOT (is_trained);
164
- const float *xt = apply_chain (n, x);
141
+ void IndexPreTransform::add(idx_t n, const float* x) {
142
+ FAISS_THROW_IF_NOT(is_trained);
143
+ const float* xt = apply_chain(n, x);
165
144
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
166
- index->add (n, xt);
145
+ index->add(n, xt);
167
146
  ntotal = index->ntotal;
168
147
  }
169
148
 
170
- void IndexPreTransform::add_with_ids (idx_t n, const float * x,
171
- const idx_t *xids)
172
- {
173
- FAISS_THROW_IF_NOT (is_trained);
174
- const float *xt = apply_chain (n, x);
149
+ void IndexPreTransform::add_with_ids(
150
+ idx_t n,
151
+ const float* x,
152
+ const idx_t* xids) {
153
+ FAISS_THROW_IF_NOT(is_trained);
154
+ const float* xt = apply_chain(n, x);
175
155
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
176
- index->add_with_ids (n, xt, xids);
156
+ index->add_with_ids(n, xt, xids);
177
157
  ntotal = index->ntotal;
178
158
  }
179
159
 
160
+ void IndexPreTransform::search(
161
+ idx_t n,
162
+ const float* x,
163
+ idx_t k,
164
+ float* distances,
165
+ idx_t* labels) const {
166
+ FAISS_THROW_IF_NOT(k > 0);
180
167
 
181
-
182
-
183
- void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
184
- float *distances, idx_t *labels) const
185
- {
186
- FAISS_THROW_IF_NOT (is_trained);
187
- const float *xt = apply_chain (n, x);
168
+ FAISS_THROW_IF_NOT(is_trained);
169
+ const float* xt = apply_chain(n, x);
188
170
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
189
- index->search (n, xt, k, distances, labels);
171
+ index->search(n, xt, k, distances, labels);
190
172
  }
191
173
 
192
- void IndexPreTransform::range_search (idx_t n, const float* x, float radius,
193
- RangeSearchResult* result) const
194
- {
195
- FAISS_THROW_IF_NOT (is_trained);
196
- const float *xt = apply_chain (n, x);
174
+ void IndexPreTransform::range_search(
175
+ idx_t n,
176
+ const float* x,
177
+ float radius,
178
+ RangeSearchResult* result) const {
179
+ FAISS_THROW_IF_NOT(is_trained);
180
+ const float* xt = apply_chain(n, x);
197
181
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
198
- index->range_search (n, xt, radius, result);
182
+ index->range_search(n, xt, radius, result);
199
183
  }
200
184
 
201
-
202
-
203
- void IndexPreTransform::reset () {
185
+ void IndexPreTransform::reset() {
204
186
  index->reset();
205
187
  ntotal = 0;
206
188
  }
207
189
 
208
- size_t IndexPreTransform::remove_ids (const IDSelector & sel) {
209
- size_t nremove = index->remove_ids (sel);
190
+ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
191
+ size_t nremove = index->remove_ids(sel);
210
192
  ntotal = index->ntotal;
211
193
  return nremove;
212
194
  }
213
195
 
214
-
215
- void IndexPreTransform::reconstruct (idx_t key, float * recons) const
216
- {
217
- float *x = chain.empty() ? recons : new float [index->d];
218
- ScopeDeleter<float> del (recons == x ? nullptr : x);
196
+ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
197
+ float* x = chain.empty() ? recons : new float[index->d];
198
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
219
199
  // Initial reconstruction
220
- index->reconstruct (key, x);
200
+ index->reconstruct(key, x);
221
201
 
222
202
  // Revert transformations from last to first
223
- reverse_chain (1, x, recons);
203
+ reverse_chain(1, x, recons);
224
204
  }
225
205
 
226
-
227
- void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
228
- {
229
- float *x = chain.empty() ? recons : new float [ni * index->d];
230
- ScopeDeleter<float> del (recons == x ? nullptr : x);
206
+ void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
207
+ float* x = chain.empty() ? recons : new float[ni * index->d];
208
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
231
209
  // Initial reconstruction
232
- index->reconstruct_n (i0, ni, x);
210
+ index->reconstruct_n(i0, ni, x);
233
211
 
234
212
  // Revert transformations from last to first
235
- reverse_chain (ni, x, recons);
213
+ reverse_chain(ni, x, recons);
236
214
  }
237
215
 
216
+ void IndexPreTransform::search_and_reconstruct(
217
+ idx_t n,
218
+ const float* x,
219
+ idx_t k,
220
+ float* distances,
221
+ idx_t* labels,
222
+ float* recons) const {
223
+ FAISS_THROW_IF_NOT(k > 0);
238
224
 
239
- void IndexPreTransform::search_and_reconstruct (
240
- idx_t n, const float *x, idx_t k,
241
- float *distances, idx_t *labels, float* recons) const
242
- {
243
- FAISS_THROW_IF_NOT (is_trained);
225
+ FAISS_THROW_IF_NOT(is_trained);
244
226
 
245
- const float* xt = apply_chain (n, x);
246
- ScopeDeleter<float> del ((xt == x) ? nullptr : xt);
227
+ const float* xt = apply_chain(n, x);
228
+ ScopeDeleter<float> del((xt == x) ? nullptr : xt);
247
229
 
248
- float* recons_temp = chain.empty() ? recons : new float [n * k * index->d];
249
- ScopeDeleter<float> del2 ((recons_temp == recons) ? nullptr : recons_temp);
250
- index->search_and_reconstruct (n, xt, k, distances, labels, recons_temp);
230
+ float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
231
+ ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
232
+ index->search_and_reconstruct(n, xt, k, distances, labels, recons_temp);
251
233
 
252
234
  // Revert transformations from last to first
253
- reverse_chain (n * k, recons_temp, recons);
235
+ reverse_chain(n * k, recons_temp, recons);
254
236
  }
255
237
 
256
- size_t IndexPreTransform::sa_code_size () const
257
- {
258
- return index->sa_code_size ();
238
+ size_t IndexPreTransform::sa_code_size() const {
239
+ return index->sa_code_size();
259
240
  }
260
241
 
261
- void IndexPreTransform::sa_encode (idx_t n, const float *x,
262
- uint8_t *bytes) const
263
- {
242
+ void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
243
+ const {
264
244
  if (chain.empty()) {
265
- index->sa_encode (n, x, bytes);
245
+ index->sa_encode(n, x, bytes);
266
246
  } else {
267
- const float *xt = apply_chain (n, x);
247
+ const float* xt = apply_chain(n, x);
268
248
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
269
- index->sa_encode (n, xt, bytes);
249
+ index->sa_encode(n, xt, bytes);
270
250
  }
271
251
  }
272
252
 
273
- void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
274
- float *x) const
275
- {
253
+ void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
254
+ const {
276
255
  if (chain.empty()) {
277
- index->sa_decode (n, bytes, x);
256
+ index->sa_decode(n, bytes, x);
278
257
  } else {
279
- std::unique_ptr<float []> x1 (new float [index->d * n]);
280
- index->sa_decode (n, bytes, x1.get());
258
+ std::unique_ptr<float[]> x1(new float[index->d * n]);
259
+ index->sa_decode(n, bytes, x1.get());
281
260
  // Revert transformations from last to first
282
- reverse_chain (n, x1.get(), x);
261
+ reverse_chain(n, x1.get(), x);
283
262
  }
284
263
  }
285
264
 
286
265
  namespace {
287
266
 
288
- struct PreTransformDistanceComputer: DistanceComputer {
289
- const IndexPreTransform *index;
267
+ struct PreTransformDistanceComputer : DistanceComputer {
268
+ const IndexPreTransform* index;
290
269
  std::unique_ptr<DistanceComputer> sub_dc;
291
- std::unique_ptr<const float []> query;
270
+ std::unique_ptr<const float[]> query;
292
271
 
293
- explicit PreTransformDistanceComputer(const IndexPreTransform *index):
294
- index(index),
295
- sub_dc(index->index->get_distance_computer())
296
- {}
272
+ explicit PreTransformDistanceComputer(const IndexPreTransform* index)
273
+ : index(index), sub_dc(index->index->get_distance_computer()) {}
297
274
 
298
- void set_query(const float *x) override {
299
- const float *xt = index->apply_chain (1, x);
275
+ void set_query(const float* x) override {
276
+ const float* xt = index->apply_chain(1, x);
300
277
  if (xt == x) {
301
- sub_dc->set_query (x);
278
+ sub_dc->set_query(x);
302
279
  } else {
303
280
  query.reset(xt);
304
- sub_dc->set_query (xt);
281
+ sub_dc->set_query(xt);
305
282
  }
306
283
  }
307
284
 
308
- float symmetric_dis(idx_t i, idx_t j) override
309
- {
285
+ float symmetric_dis(idx_t i, idx_t j) override {
310
286
  return sub_dc->symmetric_dis(i, j);
311
287
  }
312
288
 
313
- float operator () (idx_t i) override
314
- {
289
+ float operator()(idx_t i) override {
315
290
  return (*sub_dc)(i);
316
291
  }
317
-
318
292
  };
319
293
 
320
-
321
294
  } // anonymous namespace
322
295
 
323
-
324
- DistanceComputer * IndexPreTransform::get_distance_computer() const {
296
+ DistanceComputer* IndexPreTransform::get_distance_computer() const {
325
297
  if (chain.empty()) {
326
298
  return index->get_distance_computer();
327
299
  } else {
@@ -329,6 +301,4 @@ DistanceComputer * IndexPreTransform::get_distance_computer() const {
329
301
  }
330
302
  }
331
303
 
332
-
333
-
334
304
  } // namespace faiss