torch-rb 0.14.1 → 0.16.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +4 -6
- data/codegen/native_functions.yaml +552 -118
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.h +0 -23
- data/ext/torch/tensor.cpp +1 -0
- data/ext/torch/utils.h +1 -1
- data/lib/torch/inspector.rb +8 -3
- data/lib/torch/nn/elu.rb +20 -0
- data/lib/torch/nn/functional.rb +12 -0
- data/lib/torch/nn/gelu.rb +18 -0
- data/lib/torch/nn/leaky_relu.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +2 -0
- metadata +6 -11
- data/ext/torch/fft_functions.h +0 -6
- data/ext/torch/linalg_functions.h +0 -6
- data/ext/torch/nn_functions.h +0 -6
- data/ext/torch/sparse_functions.h +0 -6
- data/ext/torch/special_functions.h +0 -6
- data/ext/torch/tensor_functions.h +0 -6
- data/ext/torch/torch_functions.h +0 -6
data/ext/torch/extconf.rb
CHANGED
@@ -52,6 +52,9 @@ $INCFLAGS += " -I#{inc}"
|
|
52
52
|
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
53
53
|
|
54
54
|
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
55
|
+
if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
|
56
|
+
$LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
|
57
|
+
end
|
55
58
|
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
56
59
|
|
57
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
data/ext/torch/templates.h
CHANGED
@@ -169,27 +169,4 @@ namespace Rice::detail
|
|
169
169
|
}
|
170
170
|
}
|
171
171
|
};
|
172
|
-
|
173
|
-
template<typename T>
|
174
|
-
struct Type<torch::optional<T>>
|
175
|
-
{
|
176
|
-
static bool verify()
|
177
|
-
{
|
178
|
-
return true;
|
179
|
-
}
|
180
|
-
};
|
181
|
-
|
182
|
-
template<typename T>
|
183
|
-
class From_Ruby<torch::optional<T>>
|
184
|
-
{
|
185
|
-
public:
|
186
|
-
torch::optional<T> convert(VALUE x)
|
187
|
-
{
|
188
|
-
if (NIL_P(x)) {
|
189
|
-
return torch::nullopt;
|
190
|
-
} else {
|
191
|
-
return torch::optional<T>{From_Ruby<T>().convert(x)};
|
192
|
-
}
|
193
|
-
}
|
194
|
-
};
|
195
172
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
103
103
|
|
104
104
|
rb_cTensor
|
105
105
|
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
106
|
+
.define_method("mps?", [](Tensor& self) { return self.is_mps(); })
|
106
107
|
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
107
108
|
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
108
109
|
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/inspector.rb
CHANGED
@@ -31,9 +31,9 @@ module Torch
|
|
31
31
|
return if nonzero_finite_vals.numel == 0
|
32
32
|
|
33
33
|
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
|
34
|
-
nonzero_finite_abs = nonzero_finite_vals.abs
|
35
|
-
nonzero_finite_min = nonzero_finite_abs.min
|
36
|
-
nonzero_finite_max = nonzero_finite_abs.max
|
34
|
+
nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
|
35
|
+
nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
|
36
|
+
nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
|
37
37
|
|
38
38
|
nonzero_finite_vals.each do |value|
|
39
39
|
if value.item != value.item.ceil
|
@@ -107,6 +107,11 @@ module Torch
|
|
107
107
|
# Ruby throws error when negative, Python doesn't
|
108
108
|
" " * [@max_width - ret.size, 0].max + ret
|
109
109
|
end
|
110
|
+
|
111
|
+
def tensor_totype(t)
|
112
|
+
dtype = t.mps? ? :float : :double
|
113
|
+
t.to(dtype: dtype)
|
114
|
+
end
|
110
115
|
end
|
111
116
|
|
112
117
|
def inspect
|
data/lib/torch/nn/elu.rb
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class ELU < Module
|
4
|
+
def initialize(alpha: 1, inplace: false)
|
5
|
+
super()
|
6
|
+
@alpha = alpha
|
7
|
+
@inplace = inplace
|
8
|
+
end
|
9
|
+
|
10
|
+
def forward(input)
|
11
|
+
F.elu(input, alpha: @alpha, inplace: @inplace)
|
12
|
+
end
|
13
|
+
|
14
|
+
def extra_inspect
|
15
|
+
inplace_str = @inplace ? ", inplace: true" : ""
|
16
|
+
format("alpha: %s", @alpha) + inplace_str
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|
20
|
+
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -174,6 +174,18 @@ module Torch
|
|
174
174
|
|
175
175
|
# activation layers
|
176
176
|
|
177
|
+
def elu(input, alpha: 1, inplace: false)
|
178
|
+
if inplace
|
179
|
+
NN.elu!(input, alpha)
|
180
|
+
else
|
181
|
+
NN.elu(input, alpha)
|
182
|
+
end
|
183
|
+
end
|
184
|
+
|
185
|
+
def gelu(input, approximate: 'none')
|
186
|
+
NN.gelu(input, approximate: approximate)
|
187
|
+
end
|
188
|
+
|
177
189
|
def hardshrink(input, lambd = 0.5)
|
178
190
|
Torch.hardshrink(input, lambd)
|
179
191
|
end
|
data/lib/torch/nn/gelu.rb
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class GELU < Module
|
4
|
+
def initialize(approximate: 'none')
|
5
|
+
super()
|
6
|
+
@approximate = approximate
|
7
|
+
end
|
8
|
+
|
9
|
+
def forward(input)
|
10
|
+
F.gelu(input, approximate: @approximate)
|
11
|
+
end
|
12
|
+
|
13
|
+
def extra_inspect
|
14
|
+
"approximate: #{@approximate.inspect}"
|
15
|
+
end
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
data/lib/torch/nn/leaky_relu.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -123,6 +123,8 @@ require_relative "torch/nn/dropout3d"
|
|
123
123
|
require_relative "torch/nn/feature_alpha_dropout"
|
124
124
|
|
125
125
|
# nn activations
|
126
|
+
require_relative "torch/nn/elu"
|
127
|
+
require_relative "torch/nn/gelu"
|
126
128
|
require_relative "torch/nn/hardshrink"
|
127
129
|
require_relative "torch/nn/leaky_relu"
|
128
130
|
require_relative "torch/nn/log_sigmoid"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.14.1
|
4
|
+
version: 0.16.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-06-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -43,24 +43,17 @@ files:
|
|
43
43
|
- ext/torch/ext.cpp
|
44
44
|
- ext/torch/extconf.rb
|
45
45
|
- ext/torch/fft.cpp
|
46
|
-
- ext/torch/fft_functions.h
|
47
46
|
- ext/torch/generator.cpp
|
48
47
|
- ext/torch/ivalue.cpp
|
49
48
|
- ext/torch/linalg.cpp
|
50
|
-
- ext/torch/linalg_functions.h
|
51
49
|
- ext/torch/nn.cpp
|
52
|
-
- ext/torch/nn_functions.h
|
53
50
|
- ext/torch/random.cpp
|
54
51
|
- ext/torch/ruby_arg_parser.cpp
|
55
52
|
- ext/torch/ruby_arg_parser.h
|
56
|
-
- ext/torch/sparse_functions.h
|
57
53
|
- ext/torch/special.cpp
|
58
|
-
- ext/torch/special_functions.h
|
59
54
|
- ext/torch/templates.h
|
60
55
|
- ext/torch/tensor.cpp
|
61
|
-
- ext/torch/tensor_functions.h
|
62
56
|
- ext/torch/torch.cpp
|
63
|
-
- ext/torch/torch_functions.h
|
64
57
|
- ext/torch/utils.h
|
65
58
|
- ext/torch/wrap_outputs.h
|
66
59
|
- lib/torch-rb.rb
|
@@ -103,12 +96,14 @@ files:
|
|
103
96
|
- lib/torch/nn/dropout2d.rb
|
104
97
|
- lib/torch/nn/dropout3d.rb
|
105
98
|
- lib/torch/nn/dropoutnd.rb
|
99
|
+
- lib/torch/nn/elu.rb
|
106
100
|
- lib/torch/nn/embedding.rb
|
107
101
|
- lib/torch/nn/embedding_bag.rb
|
108
102
|
- lib/torch/nn/feature_alpha_dropout.rb
|
109
103
|
- lib/torch/nn/fold.rb
|
110
104
|
- lib/torch/nn/functional.rb
|
111
105
|
- lib/torch/nn/functional_attention.rb
|
106
|
+
- lib/torch/nn/gelu.rb
|
112
107
|
- lib/torch/nn/group_norm.rb
|
113
108
|
- lib/torch/nn/gru.rb
|
114
109
|
- lib/torch/nn/hardshrink.rb
|
@@ -230,14 +225,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
230
225
|
requirements:
|
231
226
|
- - ">="
|
232
227
|
- !ruby/object:Gem::Version
|
233
|
-
version: '3'
|
228
|
+
version: '3.1'
|
234
229
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
235
230
|
requirements:
|
236
231
|
- - ">="
|
237
232
|
- !ruby/object:Gem::Version
|
238
233
|
version: '0'
|
239
234
|
requirements: []
|
240
|
-
rubygems_version: 3.5.
|
235
|
+
rubygems_version: 3.5.11
|
241
236
|
signing_key:
|
242
237
|
specification_version: 4
|
243
238
|
summary: Deep learning for Ruby, powered by LibTorch
|
data/ext/torch/fft_functions.h
DELETED
data/ext/torch/nn_functions.h
DELETED