torch-rb 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +5 -1
- data/ext/torch/ext.cpp +31 -1
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +11 -0
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/tensor.rb +32 -39
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f17982ebcf982b779b2c88dc0ef13a9c94b2e9a6a87007a0c3136ad5f2ef261b
|
4
|
+
data.tar.gz: d4ddddfc7cbb9baee0e117e36decfe8a757903f876ddcf510db4db53363f7adf
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 00ff39bb405350f0107974b1786bf14ac6ca558744f05653ee2f16445d46f66658f82dcfa069bd3a92b44c33ba642dd196fc29a21379f188111fd7e8648f5eab
|
7
|
+
data.tar.gz: b66d48682789f71032c1d928174829791df43a04e829ebc241cafab884bc076983323bdffbcb12f5785c6d4614c13d2fed32bb5f20106957a4d6326289b7a1ea
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[Build Status](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -58,6 +68,13 @@ void Init_ext()
|
|
58
68
|
return generator.seed();
|
59
69
|
});
|
60
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
61
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
62
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
63
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -330,6 +347,18 @@ void Init_ext()
|
|
330
347
|
.define_method("numel", &torch::Tensor::numel)
|
331
348
|
.define_method("element_size", &torch::Tensor::element_size)
|
332
349
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
350
|
+
.define_method(
|
351
|
+
"_index",
|
352
|
+
*[](Tensor& self, Array indices) {
|
353
|
+
auto vec = index_vector(indices);
|
354
|
+
return self.index(vec);
|
355
|
+
})
|
356
|
+
.define_method(
|
357
|
+
"_index_put_custom",
|
358
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
359
|
+
auto vec = index_vector(indices);
|
360
|
+
return self.index_put_(vec, value);
|
361
|
+
})
|
333
362
|
.define_method(
|
334
363
|
"contiguous?",
|
335
364
|
*[](Tensor& self) {
|
@@ -508,6 +537,7 @@ void Init_ext()
|
|
508
537
|
});
|
509
538
|
|
510
539
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
540
|
+
.add_handler<torch::Error>(handle_error)
|
511
541
|
.define_singleton_method(
|
512
542
|
"_calculate_gain",
|
513
543
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -594,8 +624,8 @@ void Init_ext()
|
|
594
624
|
});
|
595
625
|
|
596
626
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
597
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
598
627
|
.add_handler<torch::Error>(handle_error)
|
628
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
599
629
|
.define_method("index", &torch::Device::index)
|
600
630
|
.define_method("index?", &torch::Device::has_index)
|
601
631
|
.define_method(
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,10 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::ScalarType;
|
14
|
+
using torch::Tensor;
|
15
|
+
|
12
16
|
// need to wrap torch::IntArrayRef() since
|
13
17
|
// it doesn't own underlying data
|
14
18
|
class IntArrayRef {
|
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
178
|
return MyReduction(x);
|
175
179
|
}
|
176
180
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
181
|
class OptionalTensor {
|
180
182
|
Object value;
|
181
183
|
public:
|
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
199
|
return OptionalTensor(x);
|
198
200
|
}
|
199
201
|
|
200
|
-
class ScalarType {
|
201
|
-
Object value;
|
202
|
-
public:
|
203
|
-
ScalarType(Object o) {
|
204
|
-
value = o;
|
205
|
-
}
|
206
|
-
operator at::ScalarType() {
|
207
|
-
throw std::runtime_error("ScalarType arguments not implemented yet");
|
208
|
-
}
|
209
|
-
};
|
210
|
-
|
211
202
|
template<>
|
212
203
|
inline
|
213
|
-
ScalarType from_ruby<ScalarType>(Object x)
|
204
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
214
205
|
{
|
215
|
-
|
206
|
+
if (x.is_nil()) {
|
207
|
+
return torch::nullopt;
|
208
|
+
} else {
|
209
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
210
|
+
}
|
216
211
|
}
|
217
212
|
|
218
|
-
class OptionalScalarType {
|
219
|
-
Object value;
|
220
|
-
public:
|
221
|
-
OptionalScalarType(Object o) {
|
222
|
-
value = o;
|
223
|
-
}
|
224
|
-
operator c10::optional<at::ScalarType>() {
|
225
|
-
if (value.is_nil()) {
|
226
|
-
return c10::nullopt;
|
227
|
-
}
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
231
|
-
|
232
213
|
template<>
|
233
214
|
inline
|
234
|
-
|
215
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
235
216
|
{
|
236
|
-
|
217
|
+
if (x.is_nil()) {
|
218
|
+
return torch::nullopt;
|
219
|
+
} else {
|
220
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
221
|
+
}
|
237
222
|
}
|
238
223
|
|
239
|
-
typedef torch::Device Device;
|
240
|
-
|
241
224
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
225
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
226
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -460,6 +460,17 @@ module Torch
|
|
460
460
|
zeros(input.size, **like_options(input, options))
|
461
461
|
end
|
462
462
|
|
463
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
464
|
+
if center
|
465
|
+
signal_dim = input.dim
|
466
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
467
|
+
pad = n_fft.div(2).to_i
|
468
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
469
|
+
input = input.view(input.shape[-signal_dim..-1])
|
470
|
+
end
|
471
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
472
|
+
end
|
473
|
+
|
463
474
|
private
|
464
475
|
|
465
476
|
def to_ivalue(obj)
|
data/lib/torch/native/function.rb
CHANGED
@@ -1,10 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
10
|
+
@tensor_options = @function["func"].include?(tensor_options_str)
|
11
|
+
@function["func"].sub!(tensor_options_str, ")")
|
8
12
|
end
|
9
13
|
|
10
14
|
def func
|
data/lib/torch/native/generator.rb
CHANGED
@@ -33,30 +33,14 @@ module Torch
|
|
33
33
|
f.args.any? do |a|
|
34
34
|
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
35
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
36
|
+
# call to 'range' is ambiguous
|
37
|
+
f.cpp_name == "_range" ||
|
36
38
|
# native_functions.yaml is missing size argument for normal
|
37
39
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
40
|
(f.base_name == "normal" && !f.out?)
|
39
41
|
end
|
40
42
|
end
|
41
43
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
44
|
# todo_functions.each do |f|
|
61
45
|
# puts f.func
|
62
46
|
# puts
|
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
|
|
97
81
|
|
98
82
|
cpp_defs = []
|
99
83
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
84
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
85
|
+
fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
|
101
86
|
|
102
87
|
cpp_args = []
|
103
88
|
fargs.each do |a|
|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
|
|
109
94
|
# TODO better signature
|
110
95
|
"OptionalTensor"
|
111
96
|
when "ScalarType?"
|
112
|
-
"OptionalScalarType"
|
97
|
+
"torch::optional<ScalarType>"
|
113
98
|
when "Tensor[]"
|
114
99
|
"TensorList"
|
115
100
|
when "Tensor?[]"
|
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
|
|
117
102
|
"TensorList"
|
118
103
|
when "int"
|
119
104
|
"int64_t"
|
105
|
+
when "int?"
|
106
|
+
"torch::optional<int64_t>"
|
120
107
|
when "float"
|
121
108
|
"double"
|
122
109
|
when /\Aint\[/
|
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
|
|
125
112
|
"Tensor &"
|
126
113
|
when "str"
|
127
114
|
"std::string"
|
115
|
+
when "TensorOptions"
|
116
|
+
"const torch::TensorOptions &"
|
128
117
|
else
|
129
118
|
a[:type]
|
130
119
|
end
|
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,8 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
86
88
|
when "float"
|
87
89
|
v.is_a?(Numeric)
|
88
90
|
when /int\[.*\]/
|
@@ -126,9 +128,11 @@ module Torch
|
|
126
128
|
end
|
127
129
|
|
128
130
|
func = candidates.first
|
131
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
132
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
133
|
{
|
130
134
|
name: func.cpp_name,
|
131
|
-
args: func.args.map { |a| final_values[a[:name]] }
|
135
|
+
args: args
|
132
136
|
}
|
133
137
|
end
|
134
138
|
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -188,49 +188,15 @@ module Torch
|
|
188
188
|
# based on python_variable_indexing.cpp and
|
189
189
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
190
190
|
def [](*indexes)
|
191
|
-
|
192
|
-
dim = 0
|
193
|
-
indexes.each do |index|
|
194
|
-
if index.is_a?(Numeric)
|
195
|
-
result = result._select_int(dim, index)
|
196
|
-
elsif index.is_a?(Range)
|
197
|
-
finish = index.end
|
198
|
-
finish += 1 unless index.exclude_end?
|
199
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
200
|
-
dim += 1
|
201
|
-
elsif index.is_a?(Tensor)
|
202
|
-
result = result.index([index])
|
203
|
-
elsif index.nil?
|
204
|
-
result = result.unsqueeze(dim)
|
205
|
-
dim += 1
|
206
|
-
elsif index == true
|
207
|
-
result = result.unsqueeze(dim)
|
208
|
-
# TODO handle false
|
209
|
-
else
|
210
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
211
|
-
end
|
212
|
-
end
|
213
|
-
result
|
191
|
+
_index(tensor_indexes(indexes))
|
214
192
|
end
|
215
193
|
|
216
194
|
# based on python_variable_indexing.cpp and
|
217
195
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
218
|
-
def []=(index, value)
|
196
|
+
def []=(*indexes, value)
|
219
197
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
220
|
-
|
221
198
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
222
|
-
|
223
|
-
if index.is_a?(Numeric)
|
224
|
-
index_put!([Torch.tensor(index)], value)
|
225
|
-
elsif index.is_a?(Range)
|
226
|
-
finish = index.end
|
227
|
-
finish += 1 unless index.exclude_end?
|
228
|
-
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
229
|
-
elsif index.is_a?(Tensor)
|
230
|
-
index_put!([index], value)
|
231
|
-
else
|
232
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
233
|
-
end
|
199
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
234
200
|
end
|
235
201
|
|
236
202
|
# native functions that need manually defined
|
@@ -244,13 +210,13 @@ module Torch
|
|
244
210
|
end
|
245
211
|
end
|
246
212
|
|
247
|
-
#
|
213
|
+
# parser can't handle overlap, so need to handle manually
|
248
214
|
def random!(*args)
|
249
215
|
case args.size
|
250
216
|
when 1
|
251
217
|
_random__to(*args)
|
252
218
|
when 2
|
253
|
-
|
219
|
+
_random__from(*args)
|
254
220
|
else
|
255
221
|
_random_(*args)
|
256
222
|
end
|
@@ -260,5 +226,32 @@ module Torch
|
|
260
226
|
_clamp_min_(min)
|
261
227
|
_clamp_max_(max)
|
262
228
|
end
|
229
|
+
|
230
|
+
private
|
231
|
+
|
232
|
+
def tensor_indexes(indexes)
|
233
|
+
indexes.map do |index|
|
234
|
+
case index
|
235
|
+
when Integer
|
236
|
+
TensorIndex.integer(index)
|
237
|
+
when Range
|
238
|
+
finish = index.end
|
239
|
+
if finish == -1 && !index.exclude_end?
|
240
|
+
finish = nil
|
241
|
+
else
|
242
|
+
finish += 1 unless index.exclude_end?
|
243
|
+
end
|
244
|
+
TensorIndex.slice(index.begin, finish)
|
245
|
+
when Tensor
|
246
|
+
TensorIndex.tensor(index)
|
247
|
+
when nil
|
248
|
+
TensorIndex.none
|
249
|
+
when true, false
|
250
|
+
TensorIndex.boolean(index)
|
251
|
+
else
|
252
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
+
end
|
254
|
+
end
|
255
|
+
end
|
263
256
|
end
|
264
257
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.2
|
4
|
+
version: 0.3.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-08-
|
11
|
+
date: 2020-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|