faiss 0.1.7 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
data/ext/faiss/pca_matrix.cpp CHANGED
@@ -1,8 +1,5 @@
1
1
  #include <faiss/VectorTransform.h>
2
2
 
3
- #include <rice/Constructor.hpp>
4
- #include <rice/Module.hpp>
5
-
6
3
  #include "utils.h"
7
4
 
8
5
  void init_pca_matrix(Rice::Module& m) {
@@ -10,25 +7,27 @@ void init_pca_matrix(Rice::Module& m) {
10
7
  .define_constructor(Rice::Constructor<faiss::PCAMatrix, int, int>())
11
8
  .define_method(
12
9
  "d_in",
13
- *[](faiss::PCAMatrix &self) {
10
+ [](faiss::PCAMatrix &self) {
14
11
  return self.d_in;
15
12
  })
16
13
  .define_method(
17
14
  "d_out",
18
- *[](faiss::PCAMatrix &self) {
15
+ [](faiss::PCAMatrix &self) {
19
16
  return self.d_out;
20
17
  })
21
18
  .define_method(
22
- "_train",
23
- *[](faiss::PCAMatrix &self, int64_t n, Rice::Object o) {
24
- const float *x = float_array(o);
25
- self.train(n, x);
19
+ "train",
20
+ [](faiss::PCAMatrix &self, numo::SFloat objects) {
21
+ auto n = check_shape(objects, self.d_in);
22
+ self.train(n, objects.read_ptr());
26
23
  })
27
24
  .define_method(
28
- "_apply",
29
- *[](faiss::PCAMatrix &self, int64_t n, Rice::Object o) {
30
- const float *x = float_array(o);
31
- float* res = self.apply(n, x);
32
- return result(res, n * self.d_out);
25
+ "apply",
26
+ [](faiss::PCAMatrix &self, numo::SFloat objects) {
27
+ auto n = check_shape(objects, self.d_in);
28
+
29
+ auto ary = numo::SFloat({n, static_cast<size_t>(self.d_out)});
30
+ self.apply_noalloc(n, objects.read_ptr(), ary.write_ptr());
31
+ return ary;
33
32
  });
34
33
  }
data/ext/faiss/product_quantizer.cpp CHANGED
@@ -1,9 +1,6 @@
1
1
  #include <faiss/impl/ProductQuantizer.h>
2
2
  #include <faiss/index_io.h>
3
3
 
4
- #include <rice/Constructor.hpp>
5
- #include <rice/Module.hpp>
6
-
7
4
  #include "utils.h"
8
5
 
9
6
  void init_product_quantizer(Rice::Module& m) {
@@ -11,44 +8,46 @@ void init_product_quantizer(Rice::Module& m) {
11
8
  .define_constructor(Rice::Constructor<faiss::ProductQuantizer, size_t, size_t, size_t>())
12
9
  .define_method(
13
10
  "d",
14
- *[](faiss::ProductQuantizer &self) {
11
+ [](faiss::ProductQuantizer &self) {
15
12
  return self.d;
16
13
  })
17
14
  .define_method(
18
15
  "m",
19
- *[](faiss::ProductQuantizer &self) {
16
+ [](faiss::ProductQuantizer &self) {
20
17
  return self.M;
21
18
  })
22
19
  .define_method(
23
- "_train",
24
- *[](faiss::ProductQuantizer &self, int n, Rice::Object o) {
25
- const float *x = float_array(o);
26
- self.train(n, x);
20
+ "train",
21
+ [](faiss::ProductQuantizer &self, numo::SFloat objects) {
22
+ auto n = check_shape(objects, self.d);
23
+ self.train(n, objects.read_ptr());
27
24
  })
28
25
  .define_method(
29
- "_compute_codes",
30
- *[](faiss::ProductQuantizer &self, int n, Rice::Object o) {
31
- const float *x = float_array(o);
32
- uint8_t *codes = new uint8_t[n * self.M];
33
- self.compute_codes(x, codes, n);
34
- return result(codes, n * self.M);
26
+ "compute_codes",
27
+ [](faiss::ProductQuantizer &self, numo::SFloat objects) {
28
+ auto n = check_shape(objects, self.d);
29
+
30
+ auto codes = numo::UInt8({n, self.M});
31
+ self.compute_codes(objects.read_ptr(), codes.write_ptr(), n);
32
+ return codes;
35
33
  })
36
34
  .define_method(
37
- "_decode",
38
- *[](faiss::ProductQuantizer &self, int n, Rice::Object o) {
39
- const uint8_t *codes = uint8_array(o);
40
- float *x = new float[n * self.d];
41
- self.decode(codes, x, n);
42
- return result(x, n * self.d);
35
+ "decode",
36
+ [](faiss::ProductQuantizer &self, numo::UInt8 objects) {
37
+ auto n = check_shape(objects, self.M);
38
+
39
+ auto x = numo::SFloat({n, self.d});
40
+ self.decode(objects.read_ptr(), x.write_ptr(), n);
41
+ return x;
43
42
  })
44
43
  .define_method(
45
44
  "save",
46
- *[](faiss::ProductQuantizer &self, const char *fname) {
45
+ [](faiss::ProductQuantizer &self, const char *fname) {
47
46
  faiss::write_ProductQuantizer(&self, fname);
48
47
  })
49
- .define_singleton_method(
48
+ .define_singleton_function(
50
49
  "load",
51
- *[](const char *fname) {
50
+ [](const char *fname) {
52
51
  return faiss::read_ProductQuantizer(fname);
53
52
  });
54
53
  }
data/ext/faiss/utils.cpp CHANGED
@@ -1,40 +1,13 @@
1
1
  #include "utils.h"
2
2
 
3
- #include <rice/Object.hpp>
4
- #include <rice/String.hpp>
5
-
6
- float* float_array(Rice::Object o)
7
- {
8
- Rice::String s = o.call("to_binary");
9
- return (float*) s.c_str();
10
- }
11
-
12
- uint8_t* uint8_array(Rice::Object o)
13
- {
14
- Rice::String s = o.call("to_binary");
15
- return (uint8_t*) s.c_str();
16
- }
17
-
18
- // TODO return Numo::SFloat
19
- Rice::String result(float* ptr, int64_t length)
20
- {
21
- return Rice::String(std::string((char*) ptr, length * sizeof(float)));
22
- }
23
-
24
- // TODO return Numo::UInt8
25
- Rice::String result(uint8_t* ptr, int64_t length)
26
- {
27
- return Rice::String(std::string((char*) ptr, length * sizeof(uint8_t)));
28
- }
29
-
30
- // TODO return Numo::Int32
31
- Rice::String result(int32_t* ptr, int64_t length)
32
- {
33
- return Rice::String(std::string((char*) ptr, length * sizeof(int32_t)));
34
- }
35
-
36
- // TODO return Numo::Int64
37
- Rice::String result(int64_t* ptr, int64_t length)
38
- {
39
- return Rice::String(std::string((char*) ptr, length * sizeof(int64_t)));
3
+ size_t check_shape(numo::NArray objects, size_t k) {
4
+ auto ndim = objects.ndim();
5
+ if (ndim != 2) {
6
+ throw Rice::Exception(rb_eArgError, "expected 2 dimensions, not %d", ndim);
7
+ }
8
+ auto shape = objects.shape();
9
+ if (shape[1] != k) {
10
+ throw Rice::Exception(rb_eArgError, "expected 2nd dimension to be %d, not %d", k, shape[1]);
11
+ }
12
+ return shape[0];
40
13
  }
data/ext/faiss/utils.h CHANGED
@@ -1,16 +1,5 @@
1
1
  #pragma once
2
2
 
3
- #include <rice/Object.hpp>
4
- #include <rice/String.hpp>
3
+ #include "numo.hpp"
5
4
 
6
- float* float_array(Rice::Object o);
7
- uint8_t* uint8_array(Rice::Object o);
8
-
9
- // TODO return Numo::SFloat
10
- Rice::String result(float* ptr, int64_t length);
11
- // TODO return Numo::UInt8
12
- Rice::String result(uint8_t* ptr, int64_t length);
13
- // TODO return Numo::Int32
14
- Rice::String result(int32_t* ptr, int64_t length);
15
- // TODO return Numo::Int64
16
- Rice::String result(int64_t* ptr, int64_t length);
5
+ size_t check_shape(numo::NArray objects, size_t k);
data/lib/faiss.rb CHANGED
@@ -5,9 +5,4 @@ require "numo/narray"
5
5
  require "faiss/ext"
6
6
 
7
7
  # modules
8
- require "faiss/index"
9
- require "faiss/index_binary"
10
- require "faiss/kmeans"
11
- require "faiss/pca_matrix"
12
- require "faiss/product_quantizer"
13
8
  require "faiss/version"
data/lib/faiss/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Faiss
2
- VERSION = "0.1.7"
2
+ VERSION = "0.2.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: faiss
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.7
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-03-26 00:00:00.000000000 Z
11
+ date: 2021-05-23 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '2.2'
19
+ version: 4.0.2
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '2.2'
26
+ version: 4.0.2
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: numo-narray
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -53,16 +53,12 @@ files:
53
53
  - ext/faiss/index.cpp
54
54
  - ext/faiss/index_binary.cpp
55
55
  - ext/faiss/kmeans.cpp
56
+ - ext/faiss/numo.hpp
56
57
  - ext/faiss/pca_matrix.cpp
57
58
  - ext/faiss/product_quantizer.cpp
58
59
  - ext/faiss/utils.cpp
59
60
  - ext/faiss/utils.h
60
61
  - lib/faiss.rb
61
- - lib/faiss/index.rb
62
- - lib/faiss/index_binary.rb
63
- - lib/faiss/kmeans.rb
64
- - lib/faiss/pca_matrix.rb
65
- - lib/faiss/product_quantizer.rb
66
62
  - lib/faiss/version.rb
67
63
  - vendor/faiss/LICENSE
68
64
  - vendor/faiss/faiss/AutoTune.cpp
@@ -257,7 +253,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
257
253
  requirements:
258
254
  - - ">="
259
255
  - !ruby/object:Gem::Version
260
- version: '2.4'
256
+ version: '2.6'
261
257
  required_rubygems_version: !ruby/object:Gem::Requirement
262
258
  requirements:
263
259
  - - ">="
data/lib/faiss/index.rb DELETED
@@ -1,20 +0,0 @@
1
- module Faiss
2
- class Index
3
- def train(objects)
4
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
5
- _train(objects.shape[0], objects)
6
- end
7
-
8
- def add(objects)
9
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
10
- _add(objects.shape[0], objects)
11
- end
12
-
13
- def search(objects, k)
14
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
15
- n = objects.shape[0]
16
- distances, labels = _search(n, objects, k)
17
- [Numo::SFloat.from_binary(distances).reshape(n, k), Numo::Int64.from_binary(labels).reshape(n, k)]
18
- end
19
- end
20
- end
data/lib/faiss/index_binary.rb DELETED
@@ -1,20 +0,0 @@
1
- module Faiss
2
- class IndexBinary
3
- def train(objects)
4
- objects = Numo::UInt8.cast(objects) unless objects.is_a?(Numo::UInt8)
5
- _train(objects.shape[0], objects)
6
- end
7
-
8
- def add(objects)
9
- objects = Numo::UInt8.cast(objects) unless objects.is_a?(Numo::UInt8)
10
- _add(objects.shape[0], objects)
11
- end
12
-
13
- def search(objects, k)
14
- objects = Numo::UInt8.cast(objects) unless objects.is_a?(Numo::UInt8)
15
- n = objects.shape[0]
16
- distances, labels = _search(n, objects, k)
17
- [Numo::UInt32.from_binary(distances).reshape(n, k), Numo::Int64.from_binary(labels).reshape(n, k)]
18
- end
19
- end
20
- end
data/lib/faiss/kmeans.rb DELETED
@@ -1,15 +0,0 @@
1
- module Faiss
2
- class Kmeans
3
- attr_reader :index
4
-
5
- def train(objects)
6
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
7
- @index = IndexFlatL2.new(d)
8
- _train(objects.shape[0], objects, @index)
9
- end
10
-
11
- def centroids
12
- Numo::SFloat.from_binary(_centroids).reshape(k, d)
13
- end
14
- end
15
- end
data/lib/faiss/pca_matrix.rb DELETED
@@ -1,15 +0,0 @@
1
- module Faiss
2
- class PCAMatrix
3
- def train(objects)
4
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
5
- _train(objects.shape[0], objects)
6
- end
7
-
8
- def apply(objects)
9
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
10
- n = objects.shape[0]
11
- res = _apply(n, objects)
12
- Numo::SFloat.from_binary(res).reshape(n, d_out)
13
- end
14
- end
15
- end
data/lib/faiss/product_quantizer.rb DELETED
@@ -1,22 +0,0 @@
1
- module Faiss
2
- class ProductQuantizer
3
- def train(objects)
4
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
5
- _train(objects.shape[0], objects)
6
- end
7
-
8
- def compute_codes(objects)
9
- objects = Numo::SFloat.cast(objects) unless objects.is_a?(Numo::SFloat)
10
- n = objects.shape[0]
11
- res = _compute_codes(n, objects)
12
- Numo::UInt8.from_binary(res).reshape(n, m)
13
- end
14
-
15
- def decode(objects)
16
- objects = Numo::UInt8.cast(objects) unless objects.is_a?(Numo::UInt8)
17
- n = objects.shape[0]
18
- res = _decode(n, objects)
19
- Numo::SFloat.from_binary(res).reshape(n, d)
20
- end
21
- end
22
- end