torch-rb 0.14.1 → 0.16.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +4 -6
- data/codegen/native_functions.yaml +552 -118
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.h +0 -23
- data/ext/torch/tensor.cpp +1 -0
- data/ext/torch/utils.h +1 -1
- data/lib/torch/inspector.rb +8 -3
- data/lib/torch/nn/elu.rb +20 -0
- data/lib/torch/nn/functional.rb +12 -0
- data/lib/torch/nn/gelu.rb +18 -0
- data/lib/torch/nn/leaky_relu.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +2 -0
- metadata +6 -11
- data/ext/torch/fft_functions.h +0 -6
- data/ext/torch/linalg_functions.h +0 -6
- data/ext/torch/nn_functions.h +0 -6
- data/ext/torch/sparse_functions.h +0 -6
- data/ext/torch/special_functions.h +0 -6
- data/ext/torch/tensor_functions.h +0 -6
- data/ext/torch/torch_functions.h +0 -6
data/ext/torch/extconf.rb
CHANGED
@@ -52,6 +52,9 @@ $INCFLAGS += " -I#{inc}"
|
|
52
52
|
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
53
53
|
|
54
54
|
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
55
|
+
if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
|
56
|
+
$LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
|
57
|
+
end
|
55
58
|
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
56
59
|
|
57
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
data/ext/torch/templates.h
CHANGED
@@ -169,27 +169,4 @@ namespace Rice::detail
|
|
169
169
|
}
|
170
170
|
}
|
171
171
|
};
|
172
|
-
|
173
|
-
template<typename T>
|
174
|
-
struct Type<torch::optional<T>>
|
175
|
-
{
|
176
|
-
static bool verify()
|
177
|
-
{
|
178
|
-
return true;
|
179
|
-
}
|
180
|
-
};
|
181
|
-
|
182
|
-
template<typename T>
|
183
|
-
class From_Ruby<torch::optional<T>>
|
184
|
-
{
|
185
|
-
public:
|
186
|
-
torch::optional<T> convert(VALUE x)
|
187
|
-
{
|
188
|
-
if (NIL_P(x)) {
|
189
|
-
return torch::nullopt;
|
190
|
-
} else {
|
191
|
-
return torch::optional<T>{From_Ruby<T>().convert(x)};
|
192
|
-
}
|
193
|
-
}
|
194
|
-
};
|
195
172
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
103
103
|
|
104
104
|
rb_cTensor
|
105
105
|
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
106
|
+
.define_method("mps?", [](Tensor& self) { return self.is_mps(); })
|
106
107
|
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
107
108
|
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
108
109
|
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/inspector.rb
CHANGED
@@ -31,9 +31,9 @@ module Torch
|
|
31
31
|
return if nonzero_finite_vals.numel == 0
|
32
32
|
|
33
33
|
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
|
34
|
-
nonzero_finite_abs = nonzero_finite_vals.abs
|
35
|
-
nonzero_finite_min = nonzero_finite_abs.min
|
36
|
-
nonzero_finite_max = nonzero_finite_abs.max
|
34
|
+
nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
|
35
|
+
nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
|
36
|
+
nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
|
37
37
|
|
38
38
|
nonzero_finite_vals.each do |value|
|
39
39
|
if value.item != value.item.ceil
|
@@ -107,6 +107,11 @@ module Torch
|
|
107
107
|
# Ruby throws error when negative, Python doesn't
|
108
108
|
" " * [@max_width - ret.size, 0].max + ret
|
109
109
|
end
|
110
|
+
|
111
|
+
def tensor_totype(t)
|
112
|
+
dtype = t.mps? ? :float : :double
|
113
|
+
t.to(dtype: dtype)
|
114
|
+
end
|
110
115
|
end
|
111
116
|
|
112
117
|
def inspect
|
data/lib/torch/nn/elu.rb
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class ELU < Module
|
4
|
+
def initialize(alpha: 1, inplace: false)
|
5
|
+
super()
|
6
|
+
@alpha = alpha
|
7
|
+
@inplace = inplace
|
8
|
+
end
|
9
|
+
|
10
|
+
def forward(input)
|
11
|
+
F.elu(input, alpha: @alpha, inplace: @inplace)
|
12
|
+
end
|
13
|
+
|
14
|
+
def extra_inspect
|
15
|
+
inplace_str = @inplace ? ", inplace: true" : ""
|
16
|
+
format("alpha: %s", @alpha) + inplace_str
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|
20
|
+
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -174,6 +174,18 @@ module Torch
|
|
174
174
|
|
175
175
|
# activation layers
|
176
176
|
|
177
|
+
def elu(input, alpha: 1, inplace: false)
|
178
|
+
if inplace
|
179
|
+
NN.elu!(input, alpha)
|
180
|
+
else
|
181
|
+
NN.elu(input, alpha)
|
182
|
+
end
|
183
|
+
end
|
184
|
+
|
185
|
+
def gelu(input, approximate: 'none')
|
186
|
+
NN.gelu(input, approximate: approximate)
|
187
|
+
end
|
188
|
+
|
177
189
|
def hardshrink(input, lambd = 0.5)
|
178
190
|
Torch.hardshrink(input, lambd)
|
179
191
|
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class GELU < Module
|
4
|
+
def initialize(approximate: 'none')
|
5
|
+
super()
|
6
|
+
@approximate = approximate
|
7
|
+
end
|
8
|
+
|
9
|
+
def forward(input)
|
10
|
+
F.gelu(input, approximate: @approximate)
|
11
|
+
end
|
12
|
+
|
13
|
+
def extra_inspect
|
14
|
+
"approximate: #{@approximate.inspect}"
|
15
|
+
end
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
data/lib/torch/nn/leaky_relu.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -123,6 +123,8 @@ require_relative "torch/nn/dropout3d"
|
|
123
123
|
require_relative "torch/nn/feature_alpha_dropout"
|
124
124
|
|
125
125
|
# nn activations
|
126
|
+
require_relative "torch/nn/elu"
|
127
|
+
require_relative "torch/nn/gelu"
|
126
128
|
require_relative "torch/nn/hardshrink"
|
127
129
|
require_relative "torch/nn/leaky_relu"
|
128
130
|
require_relative "torch/nn/log_sigmoid"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.16.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-06-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -43,24 +43,17 @@ files:
|
|
43
43
|
- ext/torch/ext.cpp
|
44
44
|
- ext/torch/extconf.rb
|
45
45
|
- ext/torch/fft.cpp
|
46
|
-
- ext/torch/fft_functions.h
|
47
46
|
- ext/torch/generator.cpp
|
48
47
|
- ext/torch/ivalue.cpp
|
49
48
|
- ext/torch/linalg.cpp
|
50
|
-
- ext/torch/linalg_functions.h
|
51
49
|
- ext/torch/nn.cpp
|
52
|
-
- ext/torch/nn_functions.h
|
53
50
|
- ext/torch/random.cpp
|
54
51
|
- ext/torch/ruby_arg_parser.cpp
|
55
52
|
- ext/torch/ruby_arg_parser.h
|
56
|
-
- ext/torch/sparse_functions.h
|
57
53
|
- ext/torch/special.cpp
|
58
|
-
- ext/torch/special_functions.h
|
59
54
|
- ext/torch/templates.h
|
60
55
|
- ext/torch/tensor.cpp
|
61
|
-
- ext/torch/tensor_functions.h
|
62
56
|
- ext/torch/torch.cpp
|
63
|
-
- ext/torch/torch_functions.h
|
64
57
|
- ext/torch/utils.h
|
65
58
|
- ext/torch/wrap_outputs.h
|
66
59
|
- lib/torch-rb.rb
|
@@ -103,12 +96,14 @@ files:
|
|
103
96
|
- lib/torch/nn/dropout2d.rb
|
104
97
|
- lib/torch/nn/dropout3d.rb
|
105
98
|
- lib/torch/nn/dropoutnd.rb
|
99
|
+
- lib/torch/nn/elu.rb
|
106
100
|
- lib/torch/nn/embedding.rb
|
107
101
|
- lib/torch/nn/embedding_bag.rb
|
108
102
|
- lib/torch/nn/feature_alpha_dropout.rb
|
109
103
|
- lib/torch/nn/fold.rb
|
110
104
|
- lib/torch/nn/functional.rb
|
111
105
|
- lib/torch/nn/functional_attention.rb
|
106
|
+
- lib/torch/nn/gelu.rb
|
112
107
|
- lib/torch/nn/group_norm.rb
|
113
108
|
- lib/torch/nn/gru.rb
|
114
109
|
- lib/torch/nn/hardshrink.rb
|
@@ -230,14 +225,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
230
225
|
requirements:
|
231
226
|
- - ">="
|
232
227
|
- !ruby/object:Gem::Version
|
233
|
-
version: '3'
|
228
|
+
version: '3.1'
|
234
229
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
235
230
|
requirements:
|
236
231
|
- - ">="
|
237
232
|
- !ruby/object:Gem::Version
|
238
233
|
version: '0'
|
239
234
|
requirements: []
|
240
|
-
rubygems_version: 3.5.
|
235
|
+
rubygems_version: 3.5.11
|
241
236
|
signing_key:
|
242
237
|
specification_version: 4
|
243
238
|
summary: Deep learning for Ruby, powered by LibTorch
|
data/ext/torch/fft_functions.h
DELETED
data/ext/torch/nn_functions.h
DELETED