faiss 0.5.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +38 -28
- data/ext/faiss/{index.cpp → index_rb.cpp} +64 -46
- data/ext/faiss/kmeans.cpp +10 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +5 -3
- data/ext/faiss/{utils.h → utils_rb.h} +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 18db7914c6db421beb972845cdf1d6489f179d6347f2cc52252b5908c7dd2db0
|
|
4
|
+
data.tar.gz: fcbc66c8b544a9f96e913c4e4e3d7a8be8c450d56db0332bc3ea1810776e424b
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: a885807f3c74cb4ce27a5931f7f2b73dc3ec9bd214c06e96d191086a841d1f366e1d8ebbe9af35143c8facd1285070d0f0b46814f5a78d1d55fb0b1821ab3fa3
|
|
7
|
+
data.tar.gz: f7ec06da8dc75e0341763f15501e2bc503ef3001dee29c88a0b0253ce78c38cb5e5b222ec75b26814a7c9f2a69e02e1ff685f108d3a4512ef8bf55e5a42214ce
|
data/CHANGELOG.md
CHANGED
data/ext/faiss/ext.cpp
CHANGED
data/ext/faiss/extconf.rb
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
require "mkmf-rice"
|
|
2
|
-
require "numo/narray"
|
|
2
|
+
require "numo/narray/alt"
|
|
3
3
|
|
|
4
4
|
# libomp changed to keg-only
|
|
5
5
|
# https://github.com/Homebrew/homebrew-core/issues/112107
|
|
@@ -13,14 +13,13 @@ abort "BLAS not found" unless have_library("blas")
|
|
|
13
13
|
abort "LAPACK not found" unless have_library("lapack")
|
|
14
14
|
abort "OpenMP not found" unless have_library("omp") || have_library("gomp")
|
|
15
15
|
|
|
16
|
-
numo = File.join(Gem.loaded_specs["numo-narray"].require_path, "numo")
|
|
16
|
+
numo = File.join(Gem.loaded_specs["numo-narray-alt"].require_path, "numo")
|
|
17
17
|
abort "Numo not found" unless find_header("numo/narray.h", numo)
|
|
18
18
|
|
|
19
19
|
# for https://bugs.ruby-lang.org/issues/19005
|
|
20
20
|
$LDFLAGS += " -Wl,-undefined,dynamic_lookup" if RbConfig::CONFIG["host_os"] =~ /darwin/i
|
|
21
21
|
|
|
22
|
-
$CXXFLAGS += " -std=c++
|
|
23
|
-
$CXXFLAGS += " -Wall -Wno-unused-parameter -Wno-unused-function -Wno-unused-variable -Wno-unused-private-field -Wno-deprecated-declarations -Wno-sign-compare"
|
|
22
|
+
$CXXFLAGS += " -std=c++20 $(optflags) -DFINTEGER=int"
|
|
24
23
|
|
|
25
24
|
# -march=native not supported with ARM Mac
|
|
26
25
|
default_optflags = RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i ? "" : " -march=native"
|
|
@@ -34,8 +33,8 @@ ext = File.expand_path(".", __dir__)
|
|
|
34
33
|
vendor = File.expand_path("../../vendor/faiss", __dir__)
|
|
35
34
|
|
|
36
35
|
$srcs = Dir["{#{ext},#{vendor}/faiss,#{vendor}/faiss/{impl,invlists,utils}/**}/*.{cpp}"]
|
|
37
|
-
$
|
|
36
|
+
$srcs -= ["avx2", "avx512", "aarch64", "arm_sve"].map { |v| "#{vendor}/faiss/utils/simd_impl/distances_#{v}.cpp" }
|
|
38
37
|
abort "Faiss not found" unless find_header("faiss/Index.h", vendor)
|
|
39
|
-
$VPATH
|
|
38
|
+
$VPATH += $srcs.filter_map { |v| File.dirname(v) if v.start_with?(vendor) }.uniq
|
|
40
39
|
|
|
41
40
|
create_makefile("faiss/ext")
|
data/ext/faiss/index_binary.cpp
CHANGED
|
@@ -1,4 +1,8 @@
|
|
|
1
1
|
#include <algorithm>
|
|
2
|
+
#include <cstddef>
|
|
3
|
+
#include <cstdint>
|
|
4
|
+
#include <limits>
|
|
5
|
+
#include <utility>
|
|
2
6
|
#include <vector>
|
|
3
7
|
|
|
4
8
|
#include <faiss/IndexBinary.h>
|
|
@@ -9,23 +13,23 @@
|
|
|
9
13
|
#include <rice/rice.hpp>
|
|
10
14
|
|
|
11
15
|
#include "numo.hpp"
|
|
12
|
-
#include "
|
|
16
|
+
#include "utils_rb.h"
|
|
13
17
|
|
|
14
18
|
void init_index_binary(Rice::Module& m) {
|
|
15
19
|
Rice::define_class_under<faiss::IndexBinary>(m, "IndexBinary")
|
|
16
20
|
.define_method(
|
|
17
21
|
"d",
|
|
18
|
-
[](faiss::IndexBinary
|
|
22
|
+
[](faiss::IndexBinary& self) {
|
|
19
23
|
return self.d;
|
|
20
24
|
})
|
|
21
25
|
.define_method(
|
|
22
26
|
"trained?",
|
|
23
|
-
[](faiss::IndexBinary
|
|
27
|
+
[](faiss::IndexBinary& self) {
|
|
24
28
|
return self.is_trained;
|
|
25
29
|
})
|
|
26
30
|
.define_method(
|
|
27
31
|
"ntotal",
|
|
28
|
-
[](faiss::IndexBinary
|
|
32
|
+
[](faiss::IndexBinary& self) {
|
|
29
33
|
return self.ntotal;
|
|
30
34
|
})
|
|
31
35
|
.define_method(
|
|
@@ -33,8 +37,8 @@ void init_index_binary(Rice::Module& m) {
|
|
|
33
37
|
[](Rice::Object rb_self, numo::UInt8 objects) {
|
|
34
38
|
check_frozen(rb_self);
|
|
35
39
|
|
|
36
|
-
auto
|
|
37
|
-
|
|
40
|
+
auto& self = *Rice::Data_Object<faiss::IndexBinary>{rb_self};
|
|
41
|
+
size_t n = check_shape(objects, self.d / 8);
|
|
38
42
|
self.train(n, objects.read_ptr());
|
|
39
43
|
})
|
|
40
44
|
.define_method(
|
|
@@ -42,8 +46,8 @@ void init_index_binary(Rice::Module& m) {
|
|
|
42
46
|
[](Rice::Object rb_self, numo::UInt8 objects) {
|
|
43
47
|
check_frozen(rb_self);
|
|
44
48
|
|
|
45
|
-
auto
|
|
46
|
-
|
|
49
|
+
auto& self = *Rice::Data_Object<faiss::IndexBinary>{rb_self};
|
|
50
|
+
size_t n = check_shape(objects, self.d / 8);
|
|
47
51
|
self.add(n, objects.read_ptr());
|
|
48
52
|
})
|
|
49
53
|
.define_method(
|
|
@@ -51,36 +55,42 @@ void init_index_binary(Rice::Module& m) {
|
|
|
51
55
|
[](Rice::Object rb_self, numo::Int64 ids) {
|
|
52
56
|
check_frozen(rb_self);
|
|
53
57
|
|
|
54
|
-
auto
|
|
58
|
+
auto& self = *Rice::Data_Object<faiss::IndexBinary>{rb_self};
|
|
55
59
|
if (ids.ndim() != 1) {
|
|
56
60
|
throw Rice::Exception(rb_eArgError, "expected ids to be 1d array");
|
|
57
61
|
}
|
|
58
|
-
|
|
62
|
+
size_t n = ids.shape()[0];
|
|
59
63
|
faiss::IDSelectorBatch sel(n, ids.read_ptr());
|
|
60
64
|
return self.remove_ids(sel);
|
|
61
65
|
})
|
|
62
66
|
.define_method(
|
|
63
67
|
"search",
|
|
64
|
-
[](Rice::Object rb_self, numo::UInt8 objects,
|
|
65
|
-
auto
|
|
66
|
-
|
|
68
|
+
[](Rice::Object rb_self, numo::UInt8 objects, int64_t k) {
|
|
69
|
+
auto& self = *Rice::Data_Object<faiss::IndexBinary>{rb_self};
|
|
70
|
+
size_t n = check_shape(objects, self.d / 8);
|
|
71
|
+
if (k <= 0) {
|
|
72
|
+
throw Rice::Exception(rb_eArgError, "expected k to be positive");
|
|
73
|
+
}
|
|
74
|
+
if (std::cmp_greater_equal(k, std::numeric_limits<size_t>::max() / n)) {
|
|
75
|
+
throw Rice::Exception(rb_eArgError, "k too large");
|
|
76
|
+
}
|
|
67
77
|
|
|
68
|
-
|
|
69
|
-
|
|
78
|
+
numo::Int32 distances({n, static_cast<size_t>(k)});
|
|
79
|
+
numo::Int64 labels({n, static_cast<size_t>(k)});
|
|
70
80
|
|
|
71
81
|
if (rb_self.is_frozen()) {
|
|
72
82
|
// Don't mess with Ruby-owned memory while the GVL is released
|
|
73
|
-
auto objects_ptr = objects.read_ptr();
|
|
83
|
+
const auto* objects_ptr = objects.read_ptr();
|
|
74
84
|
std::vector<uint8_t> objects_vec(objects_ptr, objects_ptr + n * (self.d / 8));
|
|
75
|
-
std::vector<int32_t> distances_vec(n * k);
|
|
76
|
-
std::vector<int64_t> labels_vec(n * k);
|
|
85
|
+
std::vector<int32_t> distances_vec(n * static_cast<size_t>(k));
|
|
86
|
+
std::vector<int64_t> labels_vec(n * static_cast<size_t>(k));
|
|
77
87
|
|
|
78
88
|
Rice::detail::no_gvl([&] {
|
|
79
89
|
self.search(n, objects_vec.data(), k, distances_vec.data(), labels_vec.data());
|
|
80
90
|
});
|
|
81
91
|
|
|
82
|
-
std::copy(distances_vec
|
|
83
|
-
std::copy(labels_vec
|
|
92
|
+
std::ranges::copy(distances_vec, distances.write_ptr());
|
|
93
|
+
std::ranges::copy(labels_vec, labels.write_ptr());
|
|
84
94
|
} else {
|
|
85
95
|
self.search(n, objects.read_ptr(), k, distances.write_ptr(), labels.write_ptr());
|
|
86
96
|
}
|
|
@@ -92,15 +102,15 @@ void init_index_binary(Rice::Module& m) {
|
|
|
92
102
|
})
|
|
93
103
|
.define_method(
|
|
94
104
|
"reconstruct",
|
|
95
|
-
[](faiss::IndexBinary
|
|
96
|
-
auto d = static_cast<
|
|
97
|
-
|
|
105
|
+
[](faiss::IndexBinary& self, int64_t key) {
|
|
106
|
+
auto d = static_cast<size_t>(self.d / 8);
|
|
107
|
+
numo::UInt8 recons({d});
|
|
98
108
|
self.reconstruct(key, recons.write_ptr());
|
|
99
109
|
return recons;
|
|
100
110
|
})
|
|
101
111
|
.define_method(
|
|
102
112
|
"reconstruct_n",
|
|
103
|
-
[](faiss::IndexBinary
|
|
113
|
+
[](faiss::IndexBinary& self, int64_t i0, int64_t ni) {
|
|
104
114
|
if (ni < 0) {
|
|
105
115
|
throw Rice::Exception(rb_eArgError, "expected n to be non-negative");
|
|
106
116
|
}
|
|
@@ -108,15 +118,15 @@ void init_index_binary(Rice::Module& m) {
|
|
|
108
118
|
if (i0 < 0 || i0 > self.ntotal - ni) {
|
|
109
119
|
throw Rice::Exception(rb_eIndexError, "index out of range");
|
|
110
120
|
}
|
|
111
|
-
auto d = static_cast<
|
|
112
|
-
auto n = static_cast<
|
|
113
|
-
|
|
121
|
+
auto d = static_cast<size_t>(self.d / 8);
|
|
122
|
+
auto n = static_cast<size_t>(ni);
|
|
123
|
+
numo::UInt8 recons({n, d});
|
|
114
124
|
self.reconstruct_n(i0, ni, recons.write_ptr());
|
|
115
125
|
return recons;
|
|
116
126
|
})
|
|
117
127
|
.define_method(
|
|
118
128
|
"save",
|
|
119
|
-
[](faiss::IndexBinary
|
|
129
|
+
[](faiss::IndexBinary& self, Rice::String fname) {
|
|
120
130
|
faiss::write_index_binary(&self, fname.c_str());
|
|
121
131
|
})
|
|
122
132
|
.define_singleton_function(
|
data/ext/faiss/{index.cpp → index_rb.cpp}
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
#include <algorithm>
|
|
2
|
+
#include <cstddef>
|
|
3
|
+
#include <cstdint>
|
|
4
|
+
#include <limits>
|
|
2
5
|
#include <string>
|
|
6
|
+
#include <utility>
|
|
3
7
|
#include <vector>
|
|
4
8
|
|
|
5
9
|
#include <faiss/AutoTune.h>
|
|
@@ -8,17 +12,17 @@
|
|
|
8
12
|
#include <faiss/IndexHNSW.h>
|
|
9
13
|
#include <faiss/IndexIDMap.h>
|
|
10
14
|
#include <faiss/IndexIVFFlat.h>
|
|
11
|
-
#include <faiss/IndexLSH.h>
|
|
12
|
-
#include <faiss/IndexScalarQuantizer.h>
|
|
13
|
-
#include <faiss/IndexPQ.h>
|
|
14
15
|
#include <faiss/IndexIVFPQ.h>
|
|
15
16
|
#include <faiss/IndexIVFPQR.h>
|
|
17
|
+
#include <faiss/IndexLSH.h>
|
|
18
|
+
#include <faiss/IndexPQ.h>
|
|
19
|
+
#include <faiss/IndexScalarQuantizer.h>
|
|
20
|
+
#include <faiss/MetricType.h>
|
|
16
21
|
#include <faiss/index_io.h>
|
|
17
22
|
#include <rice/rice.hpp>
|
|
18
|
-
#include <rice/stl.hpp>
|
|
19
23
|
|
|
20
24
|
#include "numo.hpp"
|
|
21
|
-
#include "
|
|
25
|
+
#include "utils_rb.h"
|
|
22
26
|
|
|
23
27
|
namespace Rice::detail {
|
|
24
28
|
template<>
|
|
@@ -33,10 +37,10 @@ namespace Rice::detail {
|
|
|
33
37
|
|
|
34
38
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
35
39
|
|
|
36
|
-
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
40
|
+
double is_convertible(VALUE /*value*/) { return Convertible::Exact; }
|
|
37
41
|
|
|
38
42
|
faiss::MetricType convert(VALUE x) {
|
|
39
|
-
|
|
43
|
+
std::string s = Object(x).to_s().str();
|
|
40
44
|
if (s == "inner_product") {
|
|
41
45
|
return faiss::METRIC_INNER_PRODUCT;
|
|
42
46
|
} else if (s == "l2") {
|
|
@@ -62,10 +66,10 @@ namespace Rice::detail {
|
|
|
62
66
|
|
|
63
67
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
64
68
|
|
|
65
|
-
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
69
|
+
double is_convertible(VALUE /*value*/) { return Convertible::Exact; }
|
|
66
70
|
|
|
67
71
|
faiss::ScalarQuantizer::QuantizerType convert(VALUE x) {
|
|
68
|
-
|
|
72
|
+
std::string s = Object(x).to_s().str();
|
|
69
73
|
if (s == "qt_8bit") {
|
|
70
74
|
return faiss::ScalarQuantizer::QT_8bit;
|
|
71
75
|
} else if (s == "qt_4bit") {
|
|
@@ -94,17 +98,17 @@ void init_index(Rice::Module& m) {
|
|
|
94
98
|
Rice::define_class_under<faiss::Index>(m, "Index")
|
|
95
99
|
.define_method(
|
|
96
100
|
"d",
|
|
97
|
-
[](faiss::Index
|
|
101
|
+
[](faiss::Index& self) {
|
|
98
102
|
return self.d;
|
|
99
103
|
})
|
|
100
104
|
.define_method(
|
|
101
105
|
"trained?",
|
|
102
|
-
[](faiss::Index
|
|
106
|
+
[](faiss::Index& self) {
|
|
103
107
|
return self.is_trained;
|
|
104
108
|
})
|
|
105
109
|
.define_method(
|
|
106
110
|
"ntotal",
|
|
107
|
-
[](faiss::Index
|
|
111
|
+
[](faiss::Index& self) {
|
|
108
112
|
return self.ntotal;
|
|
109
113
|
})
|
|
110
114
|
.define_method(
|
|
@@ -112,8 +116,8 @@ void init_index(Rice::Module& m) {
|
|
|
112
116
|
[](Rice::Object rb_self, numo::SFloat objects) {
|
|
113
117
|
check_frozen(rb_self);
|
|
114
118
|
|
|
115
|
-
auto
|
|
116
|
-
|
|
119
|
+
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
|
|
120
|
+
size_t n = check_shape(objects, self.d);
|
|
117
121
|
self.train(n, objects.read_ptr());
|
|
118
122
|
})
|
|
119
123
|
.define_method(
|
|
@@ -121,8 +125,8 @@ void init_index(Rice::Module& m) {
|
|
|
121
125
|
[](Rice::Object rb_self, numo::SFloat objects) {
|
|
122
126
|
check_frozen(rb_self);
|
|
123
127
|
|
|
124
|
-
auto
|
|
125
|
-
|
|
128
|
+
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
|
|
129
|
+
size_t n = check_shape(objects, self.d);
|
|
126
130
|
self.add(n, objects.read_ptr());
|
|
127
131
|
})
|
|
128
132
|
.define_method(
|
|
@@ -130,8 +134,8 @@ void init_index(Rice::Module& m) {
|
|
|
130
134
|
[](Rice::Object rb_self, numo::SFloat objects, numo::Int64 ids) {
|
|
131
135
|
check_frozen(rb_self);
|
|
132
136
|
|
|
133
|
-
auto
|
|
134
|
-
|
|
137
|
+
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
|
|
138
|
+
size_t n = check_shape(objects, self.d);
|
|
135
139
|
if (ids.ndim() != 1 || ids.shape()[0] != n) {
|
|
136
140
|
throw Rice::Exception(rb_eArgError, "expected ids to be 1d array with size %d", n);
|
|
137
141
|
}
|
|
@@ -142,36 +146,42 @@ void init_index(Rice::Module& m) {
|
|
|
142
146
|
[](Rice::Object rb_self, numo::Int64 ids) {
|
|
143
147
|
check_frozen(rb_self);
|
|
144
148
|
|
|
145
|
-
auto
|
|
149
|
+
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
|
|
146
150
|
if (ids.ndim() != 1) {
|
|
147
151
|
throw Rice::Exception(rb_eArgError, "expected ids to be 1d array");
|
|
148
152
|
}
|
|
149
|
-
|
|
153
|
+
size_t n = ids.shape()[0];
|
|
150
154
|
faiss::IDSelectorBatch sel(n, ids.read_ptr());
|
|
151
155
|
return self.remove_ids(sel);
|
|
152
156
|
})
|
|
153
157
|
.define_method(
|
|
154
158
|
"search",
|
|
155
|
-
[](Rice::Object rb_self, numo::SFloat objects,
|
|
156
|
-
auto
|
|
157
|
-
|
|
159
|
+
[](Rice::Object rb_self, numo::SFloat objects, int64_t k) {
|
|
160
|
+
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
|
|
161
|
+
size_t n = check_shape(objects, self.d);
|
|
162
|
+
if (k <= 0) {
|
|
163
|
+
throw Rice::Exception(rb_eArgError, "expected k to be positive");
|
|
164
|
+
}
|
|
165
|
+
if (k >= std::numeric_limits<size_t>::max() / n) {
|
|
166
|
+
throw Rice::Exception(rb_eArgError, "k too large");
|
|
167
|
+
}
|
|
158
168
|
|
|
159
|
-
|
|
160
|
-
|
|
169
|
+
numo::SFloat distances({n, static_cast<size_t>(k)});
|
|
170
|
+
numo::Int64 labels({n, static_cast<size_t>(k)});
|
|
161
171
|
|
|
162
172
|
if (rb_self.is_frozen()) {
|
|
163
173
|
// Don't mess with Ruby-owned memory while the GVL is released
|
|
164
|
-
auto objects_ptr = objects.read_ptr();
|
|
174
|
+
const auto* objects_ptr = objects.read_ptr();
|
|
165
175
|
std::vector<float> objects_vec(objects_ptr, objects_ptr + n * self.d);
|
|
166
|
-
std::vector<float> distances_vec(n * k);
|
|
167
|
-
std::vector<int64_t> labels_vec(n * k);
|
|
176
|
+
std::vector<float> distances_vec(n * static_cast<size_t>(k));
|
|
177
|
+
std::vector<int64_t> labels_vec(n * static_cast<size_t>(k));
|
|
168
178
|
|
|
169
179
|
Rice::detail::no_gvl([&] {
|
|
170
180
|
self.search(n, objects_vec.data(), k, distances_vec.data(), labels_vec.data());
|
|
171
181
|
});
|
|
172
182
|
|
|
173
|
-
std::copy(distances_vec
|
|
174
|
-
std::copy(labels_vec
|
|
183
|
+
std::ranges::copy(distances_vec, distances.write_ptr());
|
|
184
|
+
std::ranges::copy(labels_vec, labels.write_ptr());
|
|
175
185
|
} else {
|
|
176
186
|
self.search(n, objects.read_ptr(), k, distances.write_ptr(), labels.write_ptr());
|
|
177
187
|
}
|
|
@@ -186,32 +196,32 @@ void init_index(Rice::Module& m) {
|
|
|
186
196
|
[](Rice::Object rb_self, double val) {
|
|
187
197
|
check_frozen(rb_self);
|
|
188
198
|
|
|
189
|
-
auto
|
|
199
|
+
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
|
|
190
200
|
faiss::ParameterSpace().set_index_parameter(&self, "nprobe", val);
|
|
191
201
|
})
|
|
192
202
|
.define_method(
|
|
193
203
|
"reconstruct",
|
|
194
|
-
[](faiss::Index
|
|
195
|
-
auto d = static_cast<
|
|
196
|
-
|
|
204
|
+
[](faiss::Index& self, int64_t key) {
|
|
205
|
+
auto d = static_cast<size_t>(self.d);
|
|
206
|
+
numo::SFloat recons({d});
|
|
197
207
|
self.reconstruct(key, recons.write_ptr());
|
|
198
208
|
return recons;
|
|
199
209
|
})
|
|
200
210
|
.define_method(
|
|
201
211
|
"reconstruct_batch",
|
|
202
|
-
[](faiss::Index
|
|
212
|
+
[](faiss::Index& self, numo::Int64 ids) {
|
|
203
213
|
if (ids.ndim() != 1) {
|
|
204
214
|
throw Rice::Exception(rb_eArgError, "expected ids to be 1d array");
|
|
205
215
|
}
|
|
206
|
-
auto n = static_cast<
|
|
207
|
-
auto d = static_cast<
|
|
208
|
-
|
|
216
|
+
auto n = static_cast<size_t>(ids.shape()[0]);
|
|
217
|
+
auto d = static_cast<size_t>(self.d);
|
|
218
|
+
numo::SFloat recons({n, d});
|
|
209
219
|
self.reconstruct_batch(n, ids.read_ptr(), recons.write_ptr());
|
|
210
220
|
return recons;
|
|
211
221
|
})
|
|
212
222
|
.define_method(
|
|
213
223
|
"reconstruct_n",
|
|
214
|
-
[](faiss::Index
|
|
224
|
+
[](faiss::Index& self, int64_t i0, int64_t ni) {
|
|
215
225
|
if (ni < 0) {
|
|
216
226
|
throw Rice::Exception(rb_eArgError, "expected n to be non-negative");
|
|
217
227
|
}
|
|
@@ -219,15 +229,15 @@ void init_index(Rice::Module& m) {
|
|
|
219
229
|
if (i0 < 0 || i0 > self.ntotal - ni) {
|
|
220
230
|
throw Rice::Exception(rb_eIndexError, "index out of range");
|
|
221
231
|
}
|
|
222
|
-
auto d = static_cast<
|
|
223
|
-
auto n = static_cast<
|
|
224
|
-
|
|
232
|
+
auto d = static_cast<size_t>(self.d);
|
|
233
|
+
auto n = static_cast<size_t>(ni);
|
|
234
|
+
numo::SFloat recons({n, d});
|
|
225
235
|
self.reconstruct_n(i0, ni, recons.write_ptr());
|
|
226
236
|
return recons;
|
|
227
237
|
})
|
|
228
238
|
.define_method(
|
|
229
239
|
"save",
|
|
230
|
-
[](faiss::Index
|
|
240
|
+
[](faiss::Index& self, Rice::String fname) {
|
|
231
241
|
faiss::write_index(&self, fname.c_str());
|
|
232
242
|
})
|
|
233
243
|
.define_singleton_function(
|
|
@@ -273,13 +283,21 @@ void init_index(Rice::Module& m) {
|
|
|
273
283
|
.define_constructor(Rice::Constructor<faiss::ParameterSpace>())
|
|
274
284
|
.define_method(
|
|
275
285
|
"set_index_parameter",
|
|
276
|
-
[](faiss::ParameterSpace& self, faiss::Index* index,
|
|
277
|
-
self.set_index_parameter(index, name, val);
|
|
286
|
+
[](faiss::ParameterSpace& self, faiss::Index* index, Rice::String name, double val) {
|
|
287
|
+
self.set_index_parameter(index, name.str(), val);
|
|
278
288
|
});
|
|
279
289
|
|
|
280
290
|
Rice::define_class_under<faiss::IndexIDMap, faiss::Index>(m, "IndexIDMap")
|
|
281
291
|
.define_constructor(Rice::Constructor<faiss::IndexIDMap, faiss::Index*>());
|
|
282
292
|
|
|
283
293
|
Rice::define_class_under<faiss::IndexIDMap2, faiss::Index>(m, "IndexIDMap2")
|
|
284
|
-
.define_constructor(Rice::Constructor<faiss::IndexIDMap2, faiss::Index*>())
|
|
294
|
+
.define_constructor(Rice::Constructor<faiss::IndexIDMap2, faiss::Index*>())
|
|
295
|
+
.define_method(
|
|
296
|
+
"id_map",
|
|
297
|
+
[](faiss::IndexIDMap2& self) {
|
|
298
|
+
size_t n = self.id_map.size();
|
|
299
|
+
numo::Int64 ids({n});
|
|
300
|
+
std::ranges::copy(self.id_map, ids.write_ptr());
|
|
301
|
+
return ids;
|
|
302
|
+
});
|
|
285
303
|
}
|
data/ext/faiss/kmeans.cpp
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
#include <algorithm>
|
|
2
|
+
#include <cstddef>
|
|
2
3
|
|
|
3
4
|
#include <faiss/Clustering.h>
|
|
4
5
|
#include <faiss/IndexFlat.h>
|
|
5
6
|
#include <rice/rice.hpp>
|
|
6
7
|
|
|
7
8
|
#include "numo.hpp"
|
|
8
|
-
#include "
|
|
9
|
+
#include "utils_rb.h"
|
|
9
10
|
|
|
10
11
|
void init_kmeans(Rice::Module& m) {
|
|
11
12
|
Rice::define_class_under<faiss::Clustering>(m, "Kmeans")
|
|
@@ -17,27 +18,27 @@ void init_kmeans(Rice::Module& m) {
|
|
|
17
18
|
})
|
|
18
19
|
.define_method(
|
|
19
20
|
"d",
|
|
20
|
-
[](faiss::Clustering
|
|
21
|
+
[](faiss::Clustering& self) {
|
|
21
22
|
return self.d;
|
|
22
23
|
})
|
|
23
24
|
.define_method(
|
|
24
25
|
"k",
|
|
25
|
-
[](faiss::Clustering
|
|
26
|
+
[](faiss::Clustering& self) {
|
|
26
27
|
return self.k;
|
|
27
28
|
})
|
|
28
29
|
.define_method(
|
|
29
30
|
"centroids",
|
|
30
|
-
[](faiss::Clustering
|
|
31
|
-
|
|
32
|
-
std::copy(self.centroids
|
|
31
|
+
[](faiss::Clustering& self) {
|
|
32
|
+
numo::SFloat centroids({self.k, self.d});
|
|
33
|
+
std::ranges::copy(self.centroids, centroids.write_ptr());
|
|
33
34
|
return centroids;
|
|
34
35
|
})
|
|
35
36
|
.define_method(
|
|
36
37
|
"train",
|
|
37
38
|
[](Rice::Object rb_self, numo::SFloat objects) {
|
|
38
|
-
auto
|
|
39
|
-
|
|
40
|
-
|
|
39
|
+
auto& self = *Rice::Data_Object<faiss::Clustering>{rb_self};
|
|
40
|
+
size_t n = check_shape(objects, self.d);
|
|
41
|
+
faiss::IndexFlatL2 index(self.d);
|
|
41
42
|
rb_self.iv_set("@index", index);
|
|
42
43
|
self.train(n, objects.read_ptr(), index);
|
|
43
44
|
});
|
data/ext/faiss/pca_matrix.cpp
CHANGED
|
@@ -1,34 +1,36 @@
|
|
|
1
|
+
#include <cstddef>
|
|
2
|
+
|
|
1
3
|
#include <faiss/VectorTransform.h>
|
|
2
4
|
#include <rice/rice.hpp>
|
|
3
5
|
|
|
4
6
|
#include "numo.hpp"
|
|
5
|
-
#include "
|
|
7
|
+
#include "utils_rb.h"
|
|
6
8
|
|
|
7
9
|
void init_pca_matrix(Rice::Module& m) {
|
|
8
10
|
Rice::define_class_under<faiss::PCAMatrix>(m, "PCAMatrix")
|
|
9
11
|
.define_constructor(Rice::Constructor<faiss::PCAMatrix, int, int>())
|
|
10
12
|
.define_method(
|
|
11
13
|
"d_in",
|
|
12
|
-
[](faiss::PCAMatrix
|
|
14
|
+
[](faiss::PCAMatrix& self) {
|
|
13
15
|
return self.d_in;
|
|
14
16
|
})
|
|
15
17
|
.define_method(
|
|
16
18
|
"d_out",
|
|
17
|
-
[](faiss::PCAMatrix
|
|
19
|
+
[](faiss::PCAMatrix& self) {
|
|
18
20
|
return self.d_out;
|
|
19
21
|
})
|
|
20
22
|
.define_method(
|
|
21
23
|
"train",
|
|
22
|
-
[](faiss::PCAMatrix
|
|
23
|
-
|
|
24
|
+
[](faiss::PCAMatrix& self, numo::SFloat objects) {
|
|
25
|
+
size_t n = check_shape(objects, self.d_in);
|
|
24
26
|
self.train(n, objects.read_ptr());
|
|
25
27
|
})
|
|
26
28
|
.define_method(
|
|
27
29
|
"apply",
|
|
28
|
-
[](faiss::PCAMatrix
|
|
29
|
-
|
|
30
|
+
[](faiss::PCAMatrix& self, numo::SFloat objects) {
|
|
31
|
+
size_t n = check_shape(objects, self.d_in);
|
|
30
32
|
|
|
31
|
-
|
|
33
|
+
numo::SFloat ary({n, static_cast<size_t>(self.d_out)});
|
|
32
34
|
self.apply_noalloc(n, objects.read_ptr(), ary.write_ptr());
|
|
33
35
|
return ary;
|
|
34
36
|
});
|
data/ext/faiss/product_quantizer.cpp
CHANGED
|
@@ -1,50 +1,52 @@
|
|
|
1
|
+
#include <cstddef>
|
|
2
|
+
|
|
1
3
|
#include <faiss/impl/ProductQuantizer.h>
|
|
2
4
|
#include <faiss/index_io.h>
|
|
3
5
|
#include <rice/rice.hpp>
|
|
4
6
|
|
|
5
7
|
#include "numo.hpp"
|
|
6
|
-
#include "
|
|
8
|
+
#include "utils_rb.h"
|
|
7
9
|
|
|
8
10
|
void init_product_quantizer(Rice::Module& m) {
|
|
9
11
|
Rice::define_class_under<faiss::ProductQuantizer>(m, "ProductQuantizer")
|
|
10
12
|
.define_constructor(Rice::Constructor<faiss::ProductQuantizer, size_t, size_t, size_t>())
|
|
11
13
|
.define_method(
|
|
12
14
|
"d",
|
|
13
|
-
[](faiss::ProductQuantizer
|
|
15
|
+
[](faiss::ProductQuantizer& self) {
|
|
14
16
|
return self.d;
|
|
15
17
|
})
|
|
16
18
|
.define_method(
|
|
17
19
|
"m",
|
|
18
|
-
[](faiss::ProductQuantizer
|
|
20
|
+
[](faiss::ProductQuantizer& self) {
|
|
19
21
|
return self.M;
|
|
20
22
|
})
|
|
21
23
|
.define_method(
|
|
22
24
|
"train",
|
|
23
|
-
[](faiss::ProductQuantizer
|
|
24
|
-
|
|
25
|
+
[](faiss::ProductQuantizer& self, numo::SFloat objects) {
|
|
26
|
+
size_t n = check_shape(objects, self.d);
|
|
25
27
|
self.train(n, objects.read_ptr());
|
|
26
28
|
})
|
|
27
29
|
.define_method(
|
|
28
30
|
"compute_codes",
|
|
29
|
-
[](faiss::ProductQuantizer
|
|
30
|
-
|
|
31
|
+
[](faiss::ProductQuantizer& self, numo::SFloat objects) {
|
|
32
|
+
size_t n = check_shape(objects, self.d);
|
|
31
33
|
|
|
32
|
-
|
|
34
|
+
numo::UInt8 codes({n, self.M});
|
|
33
35
|
self.compute_codes(objects.read_ptr(), codes.write_ptr(), n);
|
|
34
36
|
return codes;
|
|
35
37
|
})
|
|
36
38
|
.define_method(
|
|
37
39
|
"decode",
|
|
38
|
-
[](faiss::ProductQuantizer
|
|
39
|
-
|
|
40
|
+
[](faiss::ProductQuantizer& self, numo::UInt8 objects) {
|
|
41
|
+
size_t n = check_shape(objects, self.M);
|
|
40
42
|
|
|
41
|
-
|
|
43
|
+
numo::SFloat x({n, self.d});
|
|
42
44
|
self.decode(objects.read_ptr(), x.write_ptr(), n);
|
|
43
45
|
return x;
|
|
44
46
|
})
|
|
45
47
|
.define_method(
|
|
46
48
|
"save",
|
|
47
|
-
[](faiss::ProductQuantizer
|
|
49
|
+
[](faiss::ProductQuantizer& self, Rice::String fname) {
|
|
48
50
|
faiss::write_ProductQuantizer(&self, fname.c_str());
|
|
49
51
|
})
|
|
50
52
|
.define_singleton_function(
|
data/ext/faiss/{utils.cpp → utils_rb.cpp}
CHANGED
|
@@ -1,14 +1,16 @@
|
|
|
1
|
+
#include <cstddef>
|
|
2
|
+
|
|
1
3
|
#include <rice/rice.hpp>
|
|
2
4
|
|
|
3
5
|
#include "numo.hpp"
|
|
4
|
-
#include "
|
|
6
|
+
#include "utils_rb.h"
|
|
5
7
|
|
|
6
8
|
size_t check_shape(const numo::NArray& objects, size_t k) {
|
|
7
|
-
|
|
9
|
+
size_t ndim = objects.ndim();
|
|
8
10
|
if (ndim != 2) {
|
|
9
11
|
throw Rice::Exception(rb_eArgError, "expected 2 dimensions, not %d", ndim);
|
|
10
12
|
}
|
|
11
|
-
|
|
13
|
+
size_t* shape = objects.shape();
|
|
12
14
|
if (shape[1] != k) {
|
|
13
15
|
throw Rice::Exception(rb_eArgError, "expected 2nd dimension to be %d, not %d", k, shape[1]);
|
|
14
16
|
}
|