torch-rb 0.8.0 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fdc46d7eba97851841de58e63b35badbd579ce1bc2a3e311998db03956cb84ad
4
- data.tar.gz: 944520049bd683d303e4587d18394eeb13102aa9f4addeca87875ee214806de4
3
+ metadata.gz: d52cf2bf4770e9166623614f6071e5180e9492b0063757be8bdec73a2c930b38
4
+ data.tar.gz: 1b650a3277d1aebe28cdd5d75ce54420feba7a1c9a2046335c9271c72eb3f74f
5
5
  SHA512:
6
- metadata.gz: 9026cbd9b0cdf286e79e7b296f2fae95712bd553c2cd58a34044feeed482f1ed2630b7b1e852a82e87d1011478b7e24d442d634f07de67b2c0a6a8e1cf6ade9a
7
- data.tar.gz: 193e25fc5065d42160b0584f3e53512606c1b4329c6d872b90a972d7c4d46a4be81eef8c709e3060339c96758fe504dc8618da9e69ff411d9351d528983da72b
6
+ metadata.gz: 3afad67c5ca6cedc4925dab1aadc37541daac6752a8b508a8d4bd6cd3b25b71600ed32625a65651bf150d6ddd519ec86ca7b9ccd2f12022366ed692603f65c1a
7
+ data.tar.gz: 970fa451044ce68d60e13f2da297ea8e20e53f01a1ffee31f152f190baf5d8f709f9be3defcd6b24739499cb400c5448abd43939a451be826e87b2e62eba3cac
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
1
+ ## 0.8.1 (2021-06-15)
2
+
3
+ - Added `Backends` module
4
+ - Added `FFT` module
5
+ - Added `Linalg` module
6
+ - Added `Special` module
7
+
1
8
  ## 0.8.0 (2021-06-15)
2
9
 
3
10
  - Updated LibTorch to 1.9.0
data/codegen/generate_functions.rb CHANGED
@@ -11,6 +11,9 @@ def generate_functions
11
11
  generate_files("torch", :define_singleton_method, functions[:torch])
12
12
  generate_files("tensor", :define_method, functions[:tensor])
13
13
  generate_files("nn", :define_singleton_method, functions[:nn])
14
+ generate_files("fft", :define_singleton_method, functions[:fft])
15
+ generate_files("linalg", :define_singleton_method, functions[:linalg])
16
+ generate_files("special", :define_singleton_method, functions[:special])
14
17
  end
15
18
 
16
19
  def load_functions
@@ -38,10 +41,26 @@ end
38
41
 
39
42
  def group_functions(functions)
40
43
  nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
44
+ linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" }
45
+ fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
46
+ special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
47
+ unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
41
48
  torch_functions = other_functions.select { |f| f.variants.include?("function") }
42
49
  tensor_functions = other_functions.select { |f| f.variants.include?("method") }
43
50
 
44
- {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
51
+ if unexpected_functions.any?
52
+ unexpected_modules = unexpected_functions.map(&:python_module).uniq
53
+ raise "Unexpected modules: #{unexpected_modules.join(", ")}"
54
+ end
55
+
56
+ {
57
+ torch: torch_functions,
58
+ tensor: tensor_functions,
59
+ nn: nn_functions,
60
+ linalg: linalg_functions,
61
+ fft: fft_functions,
62
+ special: special_functions
63
+ }
45
64
  end
46
65
 
47
66
  def generate_files(type, def_method, functions)
@@ -111,11 +130,14 @@ def generate_attach_def(name, type, def_method)
111
130
  end
112
131
 
113
132
  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
133
+ ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
134
+ ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
135
+ ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
114
136
 
115
137
  # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
116
138
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
117
139
 
118
- "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);"
140
+ "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
119
141
  end
120
142
 
121
143
  def generate_method_def(name, functions, type, def_method)
@@ -128,7 +150,7 @@ def generate_method_def(name, functions, type, def_method)
128
150
 
129
151
  template = <<~EOS
130
152
  // #{name}
131
- static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
153
+ static VALUE #{full_name(name, type)}(int argc, VALUE* argv, VALUE self_)
132
154
  {
133
155
  HANDLE_TH_ERRORS#{assign_self}
134
156
  static RubyArgParser parser({
@@ -560,3 +582,11 @@ def signature_type(param)
560
582
  type += "?" if param[:optional]
561
583
  type
562
584
  end
585
+
586
+ def full_name(name, type)
587
+ if %w(fft linalg special).include?(type) && name.start_with?("#{type}_")
588
+ name
589
+ else
590
+ "#{type}_#{name}"
591
+ end
592
+ end
data/ext/torch/backends.cpp ADDED
@@ -0,0 +1,17 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_backends(Rice::Module& m) {
8
+ auto rb_mBackends = Rice::define_module_under(m, "Backends");
9
+
10
+ Rice::define_module_under(rb_mBackends, "OpenMP")
11
+ .add_handler<torch::Error>(handle_error)
12
+ .define_singleton_function("available?", &torch::hasOpenMP);
13
+
14
+ Rice::define_module_under(rb_mBackends, "MKL")
15
+ .add_handler<torch::Error>(handle_error)
16
+ .define_singleton_function("available?", &torch::hasMKL);
17
+ }
data/ext/torch/ext.cpp CHANGED
@@ -2,10 +2,14 @@
2
2
 
3
3
  #include <rice/rice.hpp>
4
4
 
5
+ void init_fft(Rice::Module& m);
6
+ void init_linalg(Rice::Module& m);
5
7
  void init_nn(Rice::Module& m);
8
+ void init_special(Rice::Module& m);
6
9
  void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
7
10
  void init_torch(Rice::Module& m);
8
11
 
12
+ void init_backends(Rice::Module& m);
9
13
  void init_cuda(Rice::Module& m);
10
14
  void init_device(Rice::Module& m);
11
15
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
27
31
  init_torch(m);
28
32
  init_tensor(m, rb_cTensor, rb_cTensorOptions);
29
33
  init_nn(m);
34
+ init_fft(m);
35
+ init_linalg(m);
36
+ init_special(m);
30
37
 
38
+ init_backends(m);
31
39
  init_cuda(m);
32
40
  init_device(m);
33
41
  init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "fft_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_fft(Rice::Module& m) {
10
+ auto rb_mFFT = Rice::define_module_under(m, "FFT");
11
+ rb_mFFT.add_handler<torch::Error>(handle_error);
12
+ add_fft_functions(rb_mFFT);
13
+ }
data/ext/torch/fft_functions.h ADDED
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_fft_functions(Rice::Module& m);
data/ext/torch/linalg.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "linalg_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_linalg(Rice::Module& m) {
10
+ auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
11
+ rb_mLinalg.add_handler<torch::Error>(handle_error);
12
+ add_linalg_functions(rb_mLinalg);
13
+ }
data/ext/torch/linalg_functions.h ADDED
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_linalg_functions(Rice::Module& m);
data/ext/torch/special.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "special_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_special(Rice::Module& m) {
10
+ auto rb_mSpecial = Rice::define_module_under(m, "Special");
11
+ rb_mSpecial.add_handler<torch::Error>(handle_error);
12
+ add_special_functions(rb_mSpecial);
13
+ }
data/ext/torch/special_functions.h ADDED
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_special_functions(Rice::Module& m);
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.8.0"
2
+ VERSION = "0.8.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.0
4
+ version: 0.8.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-06-15 00:00:00.000000000 Z
11
+ date: 2021-06-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -37,16 +37,23 @@ files:
37
37
  - codegen/function.rb
38
38
  - codegen/generate_functions.rb
39
39
  - codegen/native_functions.yaml
40
+ - ext/torch/backends.cpp
40
41
  - ext/torch/cuda.cpp
41
42
  - ext/torch/device.cpp
42
43
  - ext/torch/ext.cpp
43
44
  - ext/torch/extconf.rb
45
+ - ext/torch/fft.cpp
46
+ - ext/torch/fft_functions.h
44
47
  - ext/torch/ivalue.cpp
48
+ - ext/torch/linalg.cpp
49
+ - ext/torch/linalg_functions.h
45
50
  - ext/torch/nn.cpp
46
51
  - ext/torch/nn_functions.h
47
52
  - ext/torch/random.cpp
48
53
  - ext/torch/ruby_arg_parser.cpp
49
54
  - ext/torch/ruby_arg_parser.h
55
+ - ext/torch/special.cpp
56
+ - ext/torch/special_functions.h
50
57
  - ext/torch/templates.h
51
58
  - ext/torch/tensor.cpp
52
59
  - ext/torch/tensor_functions.h