torch-rb 0.8.0 → 0.8.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fdc46d7eba97851841de58e63b35badbd579ce1bc2a3e311998db03956cb84ad
4
- data.tar.gz: 944520049bd683d303e4587d18394eeb13102aa9f4addeca87875ee214806de4
3
+ metadata.gz: d52cf2bf4770e9166623614f6071e5180e9492b0063757be8bdec73a2c930b38
4
+ data.tar.gz: 1b650a3277d1aebe28cdd5d75ce54420feba7a1c9a2046335c9271c72eb3f74f
5
5
  SHA512:
6
- metadata.gz: 9026cbd9b0cdf286e79e7b296f2fae95712bd553c2cd58a34044feeed482f1ed2630b7b1e852a82e87d1011478b7e24d442d634f07de67b2c0a6a8e1cf6ade9a
7
- data.tar.gz: 193e25fc5065d42160b0584f3e53512606c1b4329c6d872b90a972d7c4d46a4be81eef8c709e3060339c96758fe504dc8618da9e69ff411d9351d528983da72b
6
+ metadata.gz: 3afad67c5ca6cedc4925dab1aadc37541daac6752a8b508a8d4bd6cd3b25b71600ed32625a65651bf150d6ddd519ec86ca7b9ccd2f12022366ed692603f65c1a
7
+ data.tar.gz: 970fa451044ce68d60e13f2da297ea8e20e53f01a1ffee31f152f190baf5d8f709f9be3defcd6b24739499cb400c5448abd43939a451be826e87b2e62eba3cac
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
1
+ ## 0.8.1 (2021-06-15)
2
+
3
+ - Added `Backends` module
4
+ - Added `FFT` module
5
+ - Added `Linalg` module
6
+ - Added `Special` module
7
+
1
8
  ## 0.8.0 (2021-06-15)
2
9
 
3
10
  - Updated LibTorch to 1.9.0
@@ -11,6 +11,9 @@ def generate_functions
11
11
  generate_files("torch", :define_singleton_method, functions[:torch])
12
12
  generate_files("tensor", :define_method, functions[:tensor])
13
13
  generate_files("nn", :define_singleton_method, functions[:nn])
14
+ generate_files("fft", :define_singleton_method, functions[:fft])
15
+ generate_files("linalg", :define_singleton_method, functions[:linalg])
16
+ generate_files("special", :define_singleton_method, functions[:special])
14
17
  end
15
18
 
16
19
  def load_functions
@@ -38,10 +41,26 @@ end
38
41
 
39
42
  def group_functions(functions)
40
43
  nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
44
+ linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" }
45
+ fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
46
+ special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
47
+ unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
41
48
  torch_functions = other_functions.select { |f| f.variants.include?("function") }
42
49
  tensor_functions = other_functions.select { |f| f.variants.include?("method") }
43
50
 
44
- {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
51
+ if unexpected_functions.any?
52
+ unexpected_modules = unexpected_functions.map(&:python_module).uniq
53
+ raise "Unexpected modules: #{unexpected_modules.join(", ")}"
54
+ end
55
+
56
+ {
57
+ torch: torch_functions,
58
+ tensor: tensor_functions,
59
+ nn: nn_functions,
60
+ linalg: linalg_functions,
61
+ fft: fft_functions,
62
+ special: special_functions
63
+ }
45
64
  end
46
65
 
47
66
  def generate_files(type, def_method, functions)
@@ -111,11 +130,14 @@ def generate_attach_def(name, type, def_method)
111
130
  end
112
131
 
113
132
  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
133
+ ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
134
+ ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
135
+ ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
114
136
 
115
137
  # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
116
138
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
117
139
 
118
- "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);"
140
+ "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
119
141
  end
120
142
 
121
143
  def generate_method_def(name, functions, type, def_method)
@@ -128,7 +150,7 @@ def generate_method_def(name, functions, type, def_method)
128
150
 
129
151
  template = <<~EOS
130
152
  // #{name}
131
- static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
153
+ static VALUE #{full_name(name, type)}(int argc, VALUE* argv, VALUE self_)
132
154
  {
133
155
  HANDLE_TH_ERRORS#{assign_self}
134
156
  static RubyArgParser parser({
@@ -560,3 +582,11 @@ def signature_type(param)
560
582
  type += "?" if param[:optional]
561
583
  type
562
584
  end
585
+
586
+ def full_name(name, type)
587
+ if %w(fft linalg special).include?(type) && name.start_with?("#{type}_")
588
+ name
589
+ else
590
+ "#{type}_#{name}"
591
+ end
592
+ end
@@ -0,0 +1,17 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_backends(Rice::Module& m) {
8
+ auto rb_mBackends = Rice::define_module_under(m, "Backends");
9
+
10
+ Rice::define_module_under(rb_mBackends, "OpenMP")
11
+ .add_handler<torch::Error>(handle_error)
12
+ .define_singleton_function("available?", &torch::hasOpenMP);
13
+
14
+ Rice::define_module_under(rb_mBackends, "MKL")
15
+ .add_handler<torch::Error>(handle_error)
16
+ .define_singleton_function("available?", &torch::hasMKL);
17
+ }
data/ext/torch/ext.cpp CHANGED
@@ -2,10 +2,14 @@
2
2
 
3
3
  #include <rice/rice.hpp>
4
4
 
5
+ void init_fft(Rice::Module& m);
6
+ void init_linalg(Rice::Module& m);
5
7
  void init_nn(Rice::Module& m);
8
+ void init_special(Rice::Module& m);
6
9
  void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
7
10
  void init_torch(Rice::Module& m);
8
11
 
12
+ void init_backends(Rice::Module& m);
9
13
  void init_cuda(Rice::Module& m);
10
14
  void init_device(Rice::Module& m);
11
15
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
27
31
  init_torch(m);
28
32
  init_tensor(m, rb_cTensor, rb_cTensorOptions);
29
33
  init_nn(m);
34
+ init_fft(m);
35
+ init_linalg(m);
36
+ init_special(m);
30
37
 
38
+ init_backends(m);
31
39
  init_cuda(m);
32
40
  init_device(m);
33
41
  init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "fft_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_fft(Rice::Module& m) {
10
+ auto rb_mFFT = Rice::define_module_under(m, "FFT");
11
+ rb_mFFT.add_handler<torch::Error>(handle_error);
12
+ add_fft_functions(rb_mFFT);
13
+ }
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_fft_functions(Rice::Module& m);
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "linalg_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_linalg(Rice::Module& m) {
10
+ auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
11
+ rb_mLinalg.add_handler<torch::Error>(handle_error);
12
+ add_linalg_functions(rb_mLinalg);
13
+ }
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_linalg_functions(Rice::Module& m);
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "special_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_special(Rice::Module& m) {
10
+ auto rb_mSpecial = Rice::define_module_under(m, "Special");
11
+ rb_mSpecial.add_handler<torch::Error>(handle_error);
12
+ add_special_functions(rb_mSpecial);
13
+ }
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_special_functions(Rice::Module& m);
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.8.0"
2
+ VERSION = "0.8.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.0
4
+ version: 0.8.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-06-15 00:00:00.000000000 Z
11
+ date: 2021-06-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -37,16 +37,23 @@ files:
37
37
  - codegen/function.rb
38
38
  - codegen/generate_functions.rb
39
39
  - codegen/native_functions.yaml
40
+ - ext/torch/backends.cpp
40
41
  - ext/torch/cuda.cpp
41
42
  - ext/torch/device.cpp
42
43
  - ext/torch/ext.cpp
43
44
  - ext/torch/extconf.rb
45
+ - ext/torch/fft.cpp
46
+ - ext/torch/fft_functions.h
44
47
  - ext/torch/ivalue.cpp
48
+ - ext/torch/linalg.cpp
49
+ - ext/torch/linalg_functions.h
45
50
  - ext/torch/nn.cpp
46
51
  - ext/torch/nn_functions.h
47
52
  - ext/torch/random.cpp
48
53
  - ext/torch/ruby_arg_parser.cpp
49
54
  - ext/torch/ruby_arg_parser.h
55
+ - ext/torch/special.cpp
56
+ - ext/torch/special_functions.h
50
57
  - ext/torch/templates.h
51
58
  - ext/torch/tensor.cpp
52
59
  - ext/torch/tensor_functions.h