torch-rb 0.8.0 → 0.8.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/codegen/generate_functions.rb +33 -3
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/ext.cpp +8 -0
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/lib/torch/version.rb +1 -1
- metadata +9 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d52cf2bf4770e9166623614f6071e5180e9492b0063757be8bdec73a2c930b38
|
4
|
+
data.tar.gz: 1b650a3277d1aebe28cdd5d75ce54420feba7a1c9a2046335c9271c72eb3f74f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3afad67c5ca6cedc4925dab1aadc37541daac6752a8b508a8d4bd6cd3b25b71600ed32625a65651bf150d6ddd519ec86ca7b9ccd2f12022366ed692603f65c1a
|
7
|
+
data.tar.gz: 970fa451044ce68d60e13f2da297ea8e20e53f01a1ffee31f152f190baf5d8f709f9be3defcd6b24739499cb400c5448abd43939a451be826e87b2e62eba3cac
|
data/codegen/generate_functions.rb
CHANGED
@@ -11,6 +11,9 @@ def generate_functions
|
|
11
11
|
generate_files("torch", :define_singleton_method, functions[:torch])
|
12
12
|
generate_files("tensor", :define_method, functions[:tensor])
|
13
13
|
generate_files("nn", :define_singleton_method, functions[:nn])
|
14
|
+
generate_files("fft", :define_singleton_method, functions[:fft])
|
15
|
+
generate_files("linalg", :define_singleton_method, functions[:linalg])
|
16
|
+
generate_files("special", :define_singleton_method, functions[:special])
|
14
17
|
end
|
15
18
|
|
16
19
|
def load_functions
|
@@ -38,10 +41,26 @@ end
|
|
38
41
|
|
39
42
|
# Splits parsed native function declarations into the buckets that each get
# their own generated binding files.
#
# Declarations whose python_module is "nn", "linalg", "fft" or "special" land
# in the bucket of that name (declaration order is kept). Declarations with no
# python_module are split by variant: those whose variants include "function"
# go to :torch, those whose variants include "method" go to :tensor; a
# declaration with both variants appears in both buckets.
#
# @param functions [Array] declarations responding to #python_module and #variants
# @return [Hash{Symbol => Array}] keys, in order: :torch, :tensor, :nn, :linalg, :fft, :special
# @raise [RuntimeError] if any declaration names a python_module not listed above
def group_functions(functions)
  module_buckets = {}
  rest = functions

  # Peel each known module off the remaining declarations in turn.
  %w[nn linalg fft special].each do |mod|
    module_buckets[mod.to_sym], rest = rest.partition { |f| f.python_module == mod }
  end

  # Anything still carrying a python_module has no generated file to go into.
  unknown, rest = rest.partition(&:python_module)
  torch_bucket = rest.select { |f| f.variants.include?("function") }
  tensor_bucket = rest.select { |f| f.variants.include?("method") }

  unless unknown.empty?
    names = unknown.map(&:python_module).uniq
    raise "Unexpected modules: #{names.join(", ")}"
  end

  { torch: torch_bucket, tensor: tensor_bucket, **module_buckets }
end
|
46
65
|
|
47
66
|
def generate_files(type, def_method, functions)
|
@@ -111,11 +130,14 @@ def generate_attach_def(name, type, def_method)
|
|
111
130
|
end
|
112
131
|
|
113
132
|
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
|
133
|
+
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
|
134
|
+
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
|
135
|
+
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
|
114
136
|
|
115
137
|
# cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
|
116
138
|
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
|
117
139
|
|
118
|
-
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}
|
140
|
+
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
|
119
141
|
end
|
120
142
|
|
121
143
|
def generate_method_def(name, functions, type, def_method)
|
@@ -128,7 +150,7 @@ def generate_method_def(name, functions, type, def_method)
|
|
128
150
|
|
129
151
|
template = <<~EOS
|
130
152
|
// #{name}
|
131
|
-
static VALUE #{type}
|
153
|
+
static VALUE #{full_name(name, type)}(int argc, VALUE* argv, VALUE self_)
|
132
154
|
{
|
133
155
|
HANDLE_TH_ERRORS#{assign_self}
|
134
156
|
static RubyArgParser parser({
|
@@ -560,3 +582,11 @@ def signature_type(param)
|
|
560
582
|
type += "?" if param[:optional]
|
561
583
|
type
|
562
584
|
end
|
585
|
+
|
586
|
+
# Builds the C identifier used for a generated binding.
#
# Native functions in the fft/linalg/special modules already carry their
# module prefix (e.g. "fft_fft"), so they are used unchanged. Every other
# name gets "#{type}_" prepended (e.g. "torch_add", "tensor_abs").
#
# @param name [String] native function name
# @param type [String] binding group ("torch", "tensor", "nn", "fft", ...)
# @return [String] identifier for the generated C function
def full_name(name, type)
  prefix = "#{type}_"
  self_prefixed = %w(fft linalg special).include?(type) && name.start_with?(prefix)
  self_prefixed ? name : "#{prefix}#{name}"
end
|
data/ext/torch/backends.cpp
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
// Defines Torch::Backends under `m`, with one submodule per optional backend
// (OpenMP, MKL). Each submodule registers handle_error as its torch::Error
// handler and exposes a singleton `available?` that returns the result of
// torch::hasOpenMP / torch::hasMKL respectively.
void init_backends(Rice::Module& m) {
  Rice::Module backends = Rice::define_module_under(m, "Backends");

  Rice::Module openmp = Rice::define_module_under(backends, "OpenMP");
  openmp.add_handler<torch::Error>(handle_error);
  openmp.define_singleton_function("available?", &torch::hasOpenMP);

  Rice::Module mkl = Rice::define_module_under(backends, "MKL");
  mkl.add_handler<torch::Error>(handle_error);
  mkl.define_singleton_function("available?", &torch::hasMKL);
}
|
data/ext/torch/ext.cpp
CHANGED
@@ -2,10 +2,14 @@
|
|
2
2
|
|
3
3
|
#include <rice/rice.hpp>
|
4
4
|
|
5
|
+
void init_fft(Rice::Module& m);
|
6
|
+
void init_linalg(Rice::Module& m);
|
5
7
|
void init_nn(Rice::Module& m);
|
8
|
+
void init_special(Rice::Module& m);
|
6
9
|
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
|
7
10
|
void init_torch(Rice::Module& m);
|
8
11
|
|
12
|
+
void init_backends(Rice::Module& m);
|
9
13
|
void init_cuda(Rice::Module& m);
|
10
14
|
void init_device(Rice::Module& m);
|
11
15
|
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
@@ -27,7 +31,11 @@ void Init_ext()
|
|
27
31
|
init_torch(m);
|
28
32
|
init_tensor(m, rb_cTensor, rb_cTensorOptions);
|
29
33
|
init_nn(m);
|
34
|
+
init_fft(m);
|
35
|
+
init_linalg(m);
|
36
|
+
init_special(m);
|
30
37
|
|
38
|
+
init_backends(m);
|
31
39
|
init_cuda(m);
|
32
40
|
init_device(m);
|
33
41
|
init_ivalue(m, rb_cIValue);
|
data/ext/torch/fft.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "fft_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// Defines Torch::FFT under `m`, registers handle_error as the module's
// torch::Error handler, then attaches the codegen-generated bindings via
// add_fft_functions (declared in fft_functions.h).
void init_fft(Rice::Module& m) {
  Rice::Module fft_module = Rice::define_module_under(m, "FFT");
  fft_module.add_handler<torch::Error>(handle_error);
  add_fft_functions(fft_module);
}
|
data/ext/torch/linalg.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "linalg_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// Defines Torch::Linalg under `m`, registers handle_error as the module's
// torch::Error handler, then attaches the codegen-generated bindings via
// add_linalg_functions (declared in linalg_functions.h).
void init_linalg(Rice::Module& m) {
  Rice::Module linalg_module = Rice::define_module_under(m, "Linalg");
  linalg_module.add_handler<torch::Error>(handle_error);
  add_linalg_functions(linalg_module);
}
|
data/ext/torch/special.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "special_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// Defines Torch::Special under `m`, registers handle_error as the module's
// torch::Error handler, then attaches the codegen-generated bindings via
// add_special_functions (declared in special_functions.h).
void init_special(Rice::Module& m) {
  Rice::Module special_module = Rice::define_module_under(m, "Special");
  special_module.add_handler<torch::Error>(handle_error);
  add_special_functions(special_module);
}
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.0
|
4
|
+
version: 0.8.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-06-
|
11
|
+
date: 2021-06-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -37,16 +37,23 @@ files:
|
|
37
37
|
- codegen/function.rb
|
38
38
|
- codegen/generate_functions.rb
|
39
39
|
- codegen/native_functions.yaml
|
40
|
+
- ext/torch/backends.cpp
|
40
41
|
- ext/torch/cuda.cpp
|
41
42
|
- ext/torch/device.cpp
|
42
43
|
- ext/torch/ext.cpp
|
43
44
|
- ext/torch/extconf.rb
|
45
|
+
- ext/torch/fft.cpp
|
46
|
+
- ext/torch/fft_functions.h
|
44
47
|
- ext/torch/ivalue.cpp
|
48
|
+
- ext/torch/linalg.cpp
|
49
|
+
- ext/torch/linalg_functions.h
|
45
50
|
- ext/torch/nn.cpp
|
46
51
|
- ext/torch/nn_functions.h
|
47
52
|
- ext/torch/random.cpp
|
48
53
|
- ext/torch/ruby_arg_parser.cpp
|
49
54
|
- ext/torch/ruby_arg_parser.h
|
55
|
+
- ext/torch/special.cpp
|
56
|
+
- ext/torch/special_functions.h
|
50
57
|
- ext/torch/templates.h
|
51
58
|
- ext/torch/tensor.cpp
|
52
59
|
- ext/torch/tensor_functions.h
|