torch-rb 0.1.8 → 0.2.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -2
- data/README.md +35 -11
- data/ext/torch/ext.cpp +37 -28
- data/ext/torch/extconf.rb +33 -6
- data/ext/torch/nn_functions.cpp +560 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +2 -0
- data/ext/torch/tensor_functions.cpp +2085 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +3175 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/hub.rb +9 -0
- data/lib/torch/native/generator.rb +6 -3
- data/lib/torch/native/native_functions.yaml +539 -397
- data/lib/torch/native/parser.rb +2 -0
- data/lib/torch/nn/adaptive_avg_pool1d.rb +9 -0
- data/lib/torch/nn/adaptive_avg_pool2d.rb +9 -0
- data/lib/torch/nn/adaptive_avg_pool3d.rb +9 -0
- data/lib/torch/nn/adaptive_avg_poolnd.rb +14 -0
- data/lib/torch/nn/adaptive_max_pool1d.rb +9 -0
- data/lib/torch/nn/adaptive_max_pool2d.rb +9 -0
- data/lib/torch/nn/adaptive_max_pool3d.rb +9 -0
- data/lib/torch/nn/adaptive_max_poolnd.rb +15 -0
- data/lib/torch/nn/functional.rb +40 -2
- data/lib/torch/nn/module.rb +22 -1
- data/lib/torch/optim/lr_scheduler/cosine_annealing_lr.rb +29 -0
- data/lib/torch/optim/lr_scheduler/exponential_lr.rb +22 -0
- data/lib/torch/optim/lr_scheduler/lambda_lr.rb +28 -0
- data/lib/torch/optim/lr_scheduler/multi_step_lr.rb +23 -0
- data/lib/torch/optim/lr_scheduler/multiplicative_lr.rb +32 -0
- data/lib/torch/tensor.rb +8 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +21 -0
- metadata +38 -3
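
The file list above includes new adaptive pooling modules and learning-rate schedulers. A minimal usage sketch, assuming the torch-rb 0.2.0 API mirrors PyTorch's Python API (the class names come from the file paths above; argument names and calling conventions are assumptions, not confirmed against the gem):

```ruby
require "torch"

# Adaptive pooling: the output spatial size is fixed regardless of the
# input size (assumes the API mirrors PyTorch's AdaptiveAvgPool2d).
pool = Torch::NN::AdaptiveAvgPool2d.new([5, 7])
input = Torch.randn(1, 64, 8, 9)
output = pool.call(input) # shape [1, 64, 5, 7]

# Learning-rate scheduling: ExponentialLR multiplies the optimizer's
# lr by gamma on each step (constructor signature is assumed).
model = Torch::NN::Linear.new(10, 2)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
scheduler = Torch::Optim::LRScheduler::ExponentialLR.new(optimizer, gamma: 0.9)

5.times do
  # ... one training epoch would run here ...
  scheduler.step
end
```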
data/lib/torch/ext.bundle
Binary file

data/lib/torch/hub.rb
ADDED

data/lib/torch/native/generator.rb
CHANGED
```diff
@@ -18,14 +18,12 @@ module Torch
     functions = functions()
 
     # skip functions
-    skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
     skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
 
     # remove functions
     functions.reject! do |f|
       f.ruby_name.start_with?("_") ||
         f.ruby_name.end_with?("_backward") ||
-        skip_binding.include?(f.ruby_name) ||
         f.args.any? { |a| a[:type].include?("Dimname") }
     end
 
```
```diff
@@ -34,7 +32,10 @@ module Torch
     functions.partition do |f|
       f.args.any? do |a|
         a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
-          skip_args.any? { |sa| a[:type].include?(sa) }
+          skip_args.any? { |sa| a[:type].include?(sa) } ||
+          # native_functions.yaml is missing size argument for normal
+          # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
+          (f.base_name == "normal" && !f.out?)
       end
     end
 
```
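
With this change, every variant of `normal` except the out variant is excluded from code generation, because native_functions.yaml omits the size argument for the size-taking overload (see the linked C++ docs). A hedged sketch of what that leaves callable from Ruby, assuming the generated out variant accepts an `out:` keyword the way other bound out variants do (this calling convention is an assumption, not confirmed):

```ruby
# Assumed calling convention: only the out variant of normal is bound,
# so the caller pre-allocates the result tensor and passes it in.
out = Torch.empty(2, 3)
Torch.normal(Torch.zeros(2, 3), Torch.ones(2, 3), out: out)
```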
```diff
@@ -119,6 +120,8 @@ void add_%{type}_functions(Module m) {
           "IntArrayRef"
         when /Tensor\(\S!?\)/
           "Tensor &"
+        when "str"
+          "std::string"
         else
           a[:type]
         end
```
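
This hunk extends the generator's native-to-C++ type mapping so that `str` arguments produce `std::string` parameters in the generated C++ bindings. A minimal standalone sketch of that mapping step (the method name and any branches beyond those visible in the hunk are illustrative, not the gem's actual code):

```ruby
# Illustrative standalone version of the type-mapping case expression;
# only the branches visible in the hunk are reproduced faithfully.
def cpp_param_type(native_type)
  case native_type
  when /Tensor\(\S!?\)/ # mutable tensor annotations such as Tensor(a!)
    "Tensor &"
  when "str"            # new in 0.2.0: string args map to std::string
    "std::string"
  else
    native_type         # types that already match their C++ spelling
  end
end

cpp_param_type("str")        # => "std::string"
cpp_param_type("Tensor(a!)") # => "Tensor &"
```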