torch-rb 0.3.5 → 0.3.6
This diff shows the changes between publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/torch/ext.cpp +0 -10
- data/ext/torch/templates.cpp +8 -0
- data/ext/torch/templates.hpp +46 -24
- data/lib/torch.rb +16 -19
- data/lib/torch/native/function.rb +4 -0
- data/lib/torch/native/generator.rb +10 -5
- data/lib/torch/native/parser.rb +8 -0
- data/lib/torch/nn/functional.rb +6 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +6 -1
- data/lib/torch/tensor.rb +1 -6
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
+  data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
+  data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
+## 0.3.6 (2020-09-17)
+
+- Added `inplace` option for leaky ReLU
+- Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
+- Fixed error with buffers on GPU
+
 ## 0.3.5 (2020-09-04)
 
 - Fixed error with data loader (due to `dtype` of `randperm`)
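As a quick orientation, here is a hedged sketch of the headline change, the `inplace` option for leaky ReLU, using the module form (the functional form and the tensor-list fix are illustrated next to their diffs below; exact call signatures are my reading of those diffs, not an authoritative reference):

```ruby
require "torch"

x = Torch.randn(6, 4)

# New in 0.3.6: pass inplace: true to reuse the input tensor's storage
act = Torch::NN::LeakyReLU.new(negative_slope: 0.01, inplace: true)
act.call(x)
```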
data/ext/torch/ext.cpp
CHANGED
@@ -301,11 +301,6 @@ void Init_ext()
     // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
     return torch::pickle_load(v);
   })
-  .define_singleton_method(
-    "_binary_cross_entropy_with_logits",
-    *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
-      return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
-    })
   .define_singleton_method(
     "_from_blob",
     *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
@@ -379,11 +374,6 @@ void Init_ext()
     *[](Tensor& self, bool requires_grad) {
       return self.set_requires_grad(requires_grad);
     })
-  .define_method(
-    "_backward",
-    *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
-      return self.backward(gradient, create_graph, retain_graph);
-    })
   .define_method(
     "grad",
     *[](Tensor& self) {
data/ext/torch/templates.cpp
CHANGED
@@ -53,3 +53,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
   a.push(to_ruby<int64_t>(std::get<3>(x)));
   return Object(a);
 }
+
+Object wrap(std::vector<torch::Tensor> x) {
+  Array a;
+  for (auto& t : x) {
+    a.push(to_ruby<torch::Tensor>(t));
+  }
+  return Object(a);
+}
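This new `wrap` overload converts a C++ `std::vector<torch::Tensor>` into a Ruby `Array`, which is what makes the tensor-list methods mentioned in the changelog return usable values. A small usage sketch, assuming the standard torch-rb API (shapes shown are only illustrative):

```ruby
require "torch"

x = Torch.randn(6, 4)

# Each of these now comes back as a plain Ruby Array of Torch::Tensor objects
Torch.chunk(x, 3).map(&:shape)  # => [[2, 4], [2, 4], [2, 4]]
Torch.unbind(x, 0).length       # => 6
```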
data/ext/torch/templates.hpp
CHANGED
@@ -10,6 +10,7 @@
 using namespace Rice;
 
 using torch::Device;
+using torch::Scalar;
 using torch::ScalarType;
 using torch::Tensor;
 
@@ -36,30 +37,6 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
   return IntArrayRef(x);
 }
 
-// for now
-class Scalar {
-  torch::Scalar value;
-  public:
-    Scalar(Object o) {
-      // TODO cast based on Ruby type
-      if (o.rb_type() == T_FIXNUM) {
-        value = torch::Scalar(from_ruby<int64_t>(o));
-      } else {
-        value = torch::Scalar(from_ruby<float>(o));
-      }
-    }
-    operator torch::Scalar() {
-      return value;
-    }
-};
-
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
-  return Scalar(x);
-}
-
 class TensorList {
   std::vector<torch::Tensor> vec;
   public:
@@ -192,6 +169,17 @@ class OptionalTensor {
   }
 };
 
+template<>
+inline
+Scalar from_ruby<Scalar>(Object x)
+{
+  if (x.rb_type() == T_FIXNUM) {
+    return torch::Scalar(from_ruby<int64_t>(x));
+  } else {
+    return torch::Scalar(from_ruby<double>(x));
+  }
+}
+
 template<>
 inline
 OptionalTensor from_ruby<OptionalTensor>(Object x)
@@ -221,9 +209,43 @@ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
   }
 }
 
+template<>
+inline
+torch::optional<double> from_ruby<torch::optional<double>>(Object x)
+{
+  if (x.is_nil()) {
+    return torch::nullopt;
+  } else {
+    return torch::optional<double>{from_ruby<double>(x)};
+  }
+}
+
+template<>
+inline
+torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
+{
+  if (x.is_nil()) {
+    return torch::nullopt;
+  } else {
+    return torch::optional<bool>{from_ruby<bool>(x)};
+  }
+}
+
+template<>
+inline
+torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
+{
+  if (x.is_nil()) {
+    return torch::nullopt;
+  } else {
+    return torch::optional<Scalar>{from_ruby<Scalar>(x)};
+  }
+}
+
 Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
 Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
 Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
 Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
 Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
 Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
+Object wrap(std::vector<torch::Tensor> x);
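The new `from_ruby` specializations map Ruby `nil` to `torch::nullopt`, so generated functions whose schemas declare `float?`, `bool?`, or `Scalar?` parameters can accept `nil` from Ruby. A hedged sketch of what that enables; the choice of `clamp` (whose `min`/`max` are optional scalars in the native schema) is my assumption about an affected function, not something stated in this diff:

```ruby
require "torch"

x = Torch.tensor([-1.5, 0.2, 3.0])

# nil crosses the extension boundary as torch::nullopt
Torch.clamp(x, 0.0, nil)  # clamp only from below
Torch.clamp(x, nil, 1.0)  # clamp only from above
```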
data/lib/torch.rb
CHANGED
@@ -239,25 +239,22 @@ module Torch
     cls
   end
 
-  # [removed lines 242-257 elided in this rendering]
-  CUDA::IntTensor = _make_tensor_class(:int32, true)
-  CUDA::LongTensor = _make_tensor_class(:int64, true)
-  CUDA::BoolTensor = _make_tensor_class(:bool, true)
+  DTYPE_TO_CLASS = {
+    float32: "FloatTensor",
+    float64: "DoubleTensor",
+    float16: "HalfTensor",
+    uint8: "ByteTensor",
+    int8: "CharTensor",
+    int16: "ShortTensor",
+    int32: "IntTensor",
+    int64: "LongTensor",
+    bool: "BoolTensor"
+  }
+
+  DTYPE_TO_CLASS.each do |dtype, class_name|
+    const_set(class_name, _make_tensor_class(dtype))
+    CUDA.const_set(class_name, _make_tensor_class(dtype, true))
+  end
 
   class << self
     # Torch.float, Torch.long, etc
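For context, the loop above replaces a long list of hand-written constants: each dtype gets a CPU class under `Torch` and a CUDA class under `Torch::CUDA`. A small sketch of what ends up defined (the behavior of the generated classes is assumed to match the constants they replace):

```ruby
Torch::DTYPE_TO_CLASS[:float32]  # => "FloatTensor"

# One CPU class and one CUDA class per dtype, e.g.:
Torch::FloatTensor        # from const_set(class_name, _make_tensor_class(dtype))
Torch::CUDA::FloatTensor  # from CUDA.const_set(class_name, _make_tensor_class(dtype, true))
```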
data/lib/torch/native/generator.rb
CHANGED
@@ -18,12 +18,12 @@ module Torch
       functions = functions()
 
       # skip functions
-      skip_args = ["
+      skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
 
       # remove functions
       functions.reject! do |f|
         f.ruby_name.start_with?("_") ||
-          f.ruby_name.
+          f.ruby_name.include?("_backward") ||
           f.args.any? { |a| a[:type].include?("Dimname") }
       end
 
@@ -31,7 +31,6 @@ module Torch
       todo_functions, functions =
         functions.partition do |f|
           f.args.any? do |a|
-            a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
             skip_args.any? { |sa| a[:type].include?(sa) } ||
             # call to 'range' is ambiguous
             f.cpp_name == "_range" ||
@@ -104,6 +103,12 @@ void add_%{type}_functions(Module m) {
           "int64_t"
         when "int?"
           "torch::optional<int64_t>"
+        when "float?"
+          "torch::optional<double>"
+        when "bool?"
+          "torch::optional<bool>"
+        when "Scalar?"
+          "torch::optional<torch::Scalar>"
         when "float"
           "double"
         when /\Aint\[/
@@ -130,8 +135,8 @@ void add_%{type}_functions(Module m) {
       prefix = def_method == :define_method ? "self." : "torch::"
 
       body = "#{prefix}#{dispatch}(#{args.join(", ")})"
-
-      if func.ret_size > 1
+
+      if func.ret_size > 1 || func.ret_array?
         body = "wrap(#{body})"
       end
 
data/lib/torch/native/parser.rb
CHANGED
@@ -85,6 +85,10 @@ module Torch
           end
         when "int?"
           v.is_a?(Integer) || v.nil?
+        when "float?"
+          v.is_a?(Numeric) || v.nil?
+        when "bool?"
+          v == true || v == false || v.nil?
         when "float"
           v.is_a?(Numeric)
         when /int\[.*\]/
@@ -97,6 +101,10 @@ module Torch
           v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
         when "Scalar"
           v.is_a?(Numeric)
+        when "Scalar?"
+          v.is_a?(Numeric) || v.nil?
+        when "ScalarType"
+          false # not supported yet
         when "ScalarType?"
           v.nil?
         when "bool"
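These branches extend the parser's runtime type checks, which torch-rb uses to match Ruby arguments against native overloads. A standalone sketch of the same predicate logic, for illustration only (this helper is hypothetical, not the gem's internal API):

```ruby
# Hypothetical helper mirroring the new `when` branches above
def optional_type_match?(type, value)
  case type
  when "float?"  then value.is_a?(Numeric) || value.nil?
  when "bool?"   then value == true || value == false || value.nil?
  when "Scalar?" then value.is_a?(Numeric) || value.nil?
  else false
  end
end

optional_type_match?("float?", nil)   # => true
optional_type_match?("bool?", false)  # => true
optional_type_match?("Scalar?", "x")  # => false
```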
data/lib/torch/nn/functional.rb
CHANGED
@@ -178,8 +178,12 @@ module Torch
         Torch.hardshrink(input, lambd)
       end
 
-      def leaky_relu(input, negative_slope = 0.01)
-
+      def leaky_relu(input, negative_slope = 0.01, inplace: false)
+        if inplace
+          NN.leaky_relu!(input, negative_slope)
+        else
+          NN.leaky_relu(input, negative_slope)
+        end
       end
 
       def log_sigmoid(input)
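The `inplace` keyword selects between the out-of-place and in-place native kernels (`NN.leaky_relu` vs `NN.leaky_relu!`). A hedged sketch of the practical difference, assuming the methods shown above live on `Torch::NN::Functional` as class methods:

```ruby
x = Torch.tensor([-1.0, 2.0])

y = Torch::NN::Functional.leaky_relu(x, 0.01)                 # x is left untouched
z = Torch::NN::Functional.leaky_relu(x, 0.01, inplace: true)  # writes the result into x (NN.leaky_relu!)
```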
data/lib/torch/nn/leaky_relu.rb
CHANGED
@@ -1,14 +1,14 @@
 module Torch
   module NN
     class LeakyReLU < Module
-      def initialize(negative_slope: 1e-2)
+      def initialize(negative_slope: 1e-2, inplace: false)
         super()
         @negative_slope = negative_slope
-
+        @inplace = inplace
       end
 
       def forward(input)
-        F.leaky_relu(input, @negative_slope)
+        F.leaky_relu(input, @negative_slope, inplace: @inplace)
       end
 
       def extra_inspect
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -103,11 +103,6 @@ module Torch
       Torch.empty(0, dtype: dtype)
     end
 
-    def backward(gradient = nil, retain_graph: nil, create_graph: false)
-      retain_graph = create_graph if retain_graph.nil?
-      _backward(gradient, retain_graph, create_graph)
-    end
-
     # TODO read directly from memory
     def numo
       cls = Torch._dtype_to_numo[dtype]
@@ -235,7 +230,7 @@ module Torch
       when Integer
         TensorIndex.integer(index)
       when Range
-        finish = index.end
+        finish = index.end || -1
         if finish == -1 && !index.exclude_end?
           finish = nil
         else
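The `|| -1` above appears to be what makes endless Ranges work as indices: for an endless range, `index.end` is `nil`, which is now treated like `-1` and then converted into an open-ended slice. A small hedged illustration:

```ruby
x = Torch.arange(10)

x[2..-1]  # explicit form: elements 2 through the end
x[2..]    # endless Range: index.end is nil, now handled like -1 by the code above
```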
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.3.5
+  version: 0.3.6
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-09-
+date: 2020-09-18 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice