torch-rb 0.8.0 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/codegen/generate_functions.rb +33 -3
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/ext.cpp +8 -0
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/lib/torch/version.rb +1 -1
- metadata +9 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d52cf2bf4770e9166623614f6071e5180e9492b0063757be8bdec73a2c930b38
|
4
|
+
data.tar.gz: 1b650a3277d1aebe28cdd5d75ce54420feba7a1c9a2046335c9271c72eb3f74f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3afad67c5ca6cedc4925dab1aadc37541daac6752a8b508a8d4bd6cd3b25b71600ed32625a65651bf150d6ddd519ec86ca7b9ccd2f12022366ed692603f65c1a
|
7
|
+
data.tar.gz: 970fa451044ce68d60e13f2da297ea8e20e53f01a1ffee31f152f190baf5d8f709f9be3defcd6b24739499cb400c5448abd43939a451be826e87b2e62eba3cac
|
data/CHANGELOG.md
CHANGED
data/codegen/generate_functions.rb
CHANGED
@@ -11,6 +11,9 @@ def generate_functions
|
|
11
11
|
generate_files("torch", :define_singleton_method, functions[:torch])
|
12
12
|
generate_files("tensor", :define_method, functions[:tensor])
|
13
13
|
generate_files("nn", :define_singleton_method, functions[:nn])
|
14
|
+
generate_files("fft", :define_singleton_method, functions[:fft])
|
15
|
+
generate_files("linalg", :define_singleton_method, functions[:linalg])
|
16
|
+
generate_files("special", :define_singleton_method, functions[:special])
|
14
17
|
end
|
15
18
|
|
16
19
|
def load_functions
|
@@ -38,10 +41,26 @@ end
|
|
38
41
|
|
39
42
|
def group_functions(functions)
|
40
43
|
nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
|
44
|
+
linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" }
|
45
|
+
fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
|
46
|
+
special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
|
47
|
+
unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
|
41
48
|
torch_functions = other_functions.select { |f| f.variants.include?("function") }
|
42
49
|
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
|
43
50
|
|
44
|
-
|
51
|
+
if unexpected_functions.any?
|
52
|
+
unexpected_modules = unexpected_functions.map(&:python_module).uniq
|
53
|
+
raise "Unexpected modules: #{unexpected_modules.join(", ")}"
|
54
|
+
end
|
55
|
+
|
56
|
+
{
|
57
|
+
torch: torch_functions,
|
58
|
+
tensor: tensor_functions,
|
59
|
+
nn: nn_functions,
|
60
|
+
linalg: linalg_functions,
|
61
|
+
fft: fft_functions,
|
62
|
+
special: special_functions
|
63
|
+
}
|
45
64
|
end
|
46
65
|
|
47
66
|
def generate_files(type, def_method, functions)
|
@@ -111,11 +130,14 @@ def generate_attach_def(name, type, def_method)
|
|
111
130
|
end
|
112
131
|
|
113
132
|
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
|
133
|
+
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
|
134
|
+
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
|
135
|
+
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
|
114
136
|
|
115
137
|
# cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
|
116
138
|
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
|
117
139
|
|
118
|
-
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);"
|
140
|
+
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
|
119
141
|
end
|
120
142
|
|
121
143
|
def generate_method_def(name, functions, type, def_method)
|
@@ -128,7 +150,7 @@ def generate_method_def(name, functions, type, def_method)
|
|
128
150
|
|
129
151
|
template = <<~EOS
|
130
152
|
// #{name}
|
131
|
-
static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
|
153
|
+
static VALUE #{full_name(name, type)}(int argc, VALUE* argv, VALUE self_)
|
132
154
|
{
|
133
155
|
HANDLE_TH_ERRORS#{assign_self}
|
134
156
|
static RubyArgParser parser({
|
@@ -560,3 +582,11 @@ def signature_type(param)
|
|
560
582
|
type += "?" if param[:optional]
|
561
583
|
type
|
562
584
|
end
|
585
|
+
|
586
|
+
def full_name(name, type)
|
587
|
+
if %w(fft linalg special).include?(type) && name.start_with?("#{type}_")
|
588
|
+
name
|
589
|
+
else
|
590
|
+
"#{type}_#{name}"
|
591
|
+
end
|
592
|
+
end
|
data/ext/torch/backends.cpp
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
void init_backends(Rice::Module& m) {
|
8
|
+
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
|
+
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
|
+
.add_handler<torch::Error>(handle_error)
|
12
|
+
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
|
+
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
|
+
.add_handler<torch::Error>(handle_error)
|
16
|
+
.define_singleton_function("available?", &torch::hasMKL);
|
17
|
+
}
|
data/ext/torch/ext.cpp
CHANGED
@@ -2,10 +2,14 @@
|
|
2
2
|
|
3
3
|
#include <rice/rice.hpp>
|
4
4
|
|
5
|
+
void init_fft(Rice::Module& m);
|
6
|
+
void init_linalg(Rice::Module& m);
|
5
7
|
void init_nn(Rice::Module& m);
|
8
|
+
void init_special(Rice::Module& m);
|
6
9
|
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
|
7
10
|
void init_torch(Rice::Module& m);
|
8
11
|
|
12
|
+
void init_backends(Rice::Module& m);
|
9
13
|
void init_cuda(Rice::Module& m);
|
10
14
|
void init_device(Rice::Module& m);
|
11
15
|
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
@@ -27,7 +31,11 @@ void Init_ext()
|
|
27
31
|
init_torch(m);
|
28
32
|
init_tensor(m, rb_cTensor, rb_cTensorOptions);
|
29
33
|
init_nn(m);
|
34
|
+
init_fft(m);
|
35
|
+
init_linalg(m);
|
36
|
+
init_special(m);
|
30
37
|
|
38
|
+
init_backends(m);
|
31
39
|
init_cuda(m);
|
32
40
|
init_device(m);
|
33
41
|
init_ivalue(m, rb_cIValue);
|
data/ext/torch/fft.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "fft_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_fft(Rice::Module& m) {
|
10
|
+
auto rb_mFFT = Rice::define_module_under(m, "FFT");
|
11
|
+
rb_mFFT.add_handler<torch::Error>(handle_error);
|
12
|
+
add_fft_functions(rb_mFFT);
|
13
|
+
}
|
data/ext/torch/linalg.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "linalg_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_linalg(Rice::Module& m) {
|
10
|
+
auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
|
11
|
+
rb_mLinalg.add_handler<torch::Error>(handle_error);
|
12
|
+
add_linalg_functions(rb_mLinalg);
|
13
|
+
}
|
data/ext/torch/special.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "special_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_special(Rice::Module& m) {
|
10
|
+
auto rb_mSpecial = Rice::define_module_under(m, "Special");
|
11
|
+
rb_mSpecial.add_handler<torch::Error>(handle_error);
|
12
|
+
add_special_functions(rb_mSpecial);
|
13
|
+
}
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.0
|
4
|
+
version: 0.8.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-06-
|
11
|
+
date: 2021-06-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -37,16 +37,23 @@ files:
|
|
37
37
|
- codegen/function.rb
|
38
38
|
- codegen/generate_functions.rb
|
39
39
|
- codegen/native_functions.yaml
|
40
|
+
- ext/torch/backends.cpp
|
40
41
|
- ext/torch/cuda.cpp
|
41
42
|
- ext/torch/device.cpp
|
42
43
|
- ext/torch/ext.cpp
|
43
44
|
- ext/torch/extconf.rb
|
45
|
+
- ext/torch/fft.cpp
|
46
|
+
- ext/torch/fft_functions.h
|
44
47
|
- ext/torch/ivalue.cpp
|
48
|
+
- ext/torch/linalg.cpp
|
49
|
+
- ext/torch/linalg_functions.h
|
45
50
|
- ext/torch/nn.cpp
|
46
51
|
- ext/torch/nn_functions.h
|
47
52
|
- ext/torch/random.cpp
|
48
53
|
- ext/torch/ruby_arg_parser.cpp
|
49
54
|
- ext/torch/ruby_arg_parser.h
|
55
|
+
- ext/torch/special.cpp
|
56
|
+
- ext/torch/special_functions.h
|
50
57
|
- ext/torch/templates.h
|
51
58
|
- ext/torch/tensor.cpp
|
52
59
|
- ext/torch/tensor_functions.h
|