torch-rb 0.3.2 → 0.3.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +5 -1
- data/ext/torch/ext.cpp +31 -1
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +11 -0
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/tensor.rb +32 -39
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f17982ebcf982b779b2c88dc0ef13a9c94b2e9a6a87007a0c3136ad5f2ef261b
|
4
|
+
data.tar.gz: d4ddddfc7cbb9baee0e117e36decfe8a757903f876ddcf510db4db53363f7adf
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 00ff39bb405350f0107974b1786bf14ac6ca558744f05653ee2f16445d46f66658f82dcfa069bd3a92b44c33ba642dd196fc29a21379f188111fd7e8648f5eab
|
7
|
+
data.tar.gz: b66d48682789f71032c1d928174829791df43a04e829ebc241cafab884bc076983323bdffbcb12f5785c6d4614c13d2fed32bb5f20106957a4d6326289b7a1ea
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -58,6 +68,13 @@ void Init_ext()
|
|
58
68
|
return generator.seed();
|
59
69
|
});
|
60
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
61
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
62
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
63
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -330,6 +347,18 @@ void Init_ext()
|
|
330
347
|
.define_method("numel", &torch::Tensor::numel)
|
331
348
|
.define_method("element_size", &torch::Tensor::element_size)
|
332
349
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
350
|
+
.define_method(
|
351
|
+
"_index",
|
352
|
+
*[](Tensor& self, Array indices) {
|
353
|
+
auto vec = index_vector(indices);
|
354
|
+
return self.index(vec);
|
355
|
+
})
|
356
|
+
.define_method(
|
357
|
+
"_index_put_custom",
|
358
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
359
|
+
auto vec = index_vector(indices);
|
360
|
+
return self.index_put_(vec, value);
|
361
|
+
})
|
333
362
|
.define_method(
|
334
363
|
"contiguous?",
|
335
364
|
*[](Tensor& self) {
|
@@ -508,6 +537,7 @@ void Init_ext()
|
|
508
537
|
});
|
509
538
|
|
510
539
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
540
|
+
.add_handler<torch::Error>(handle_error)
|
511
541
|
.define_singleton_method(
|
512
542
|
"_calculate_gain",
|
513
543
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -594,8 +624,8 @@ void Init_ext()
|
|
594
624
|
});
|
595
625
|
|
596
626
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
597
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
598
627
|
.add_handler<torch::Error>(handle_error)
|
628
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
599
629
|
.define_method("index", &torch::Device::index)
|
600
630
|
.define_method("index?", &torch::Device::has_index)
|
601
631
|
.define_method(
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,10 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::ScalarType;
|
14
|
+
using torch::Tensor;
|
15
|
+
|
12
16
|
// need to wrap torch::IntArrayRef() since
|
13
17
|
// it doesn't own underlying data
|
14
18
|
class IntArrayRef {
|
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
178
|
return MyReduction(x);
|
175
179
|
}
|
176
180
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
181
|
class OptionalTensor {
|
180
182
|
Object value;
|
181
183
|
public:
|
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
199
|
return OptionalTensor(x);
|
198
200
|
}
|
199
201
|
|
200
|
-
class ScalarType {
|
201
|
-
Object value;
|
202
|
-
public:
|
203
|
-
ScalarType(Object o) {
|
204
|
-
value = o;
|
205
|
-
}
|
206
|
-
operator at::ScalarType() {
|
207
|
-
throw std::runtime_error("ScalarType arguments not implemented yet");
|
208
|
-
}
|
209
|
-
};
|
210
|
-
|
211
202
|
template<>
|
212
203
|
inline
|
213
|
-
ScalarType from_ruby<ScalarType
|
204
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
214
205
|
{
|
215
|
-
|
206
|
+
if (x.is_nil()) {
|
207
|
+
return torch::nullopt;
|
208
|
+
} else {
|
209
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
210
|
+
}
|
216
211
|
}
|
217
212
|
|
218
|
-
class OptionalScalarType {
|
219
|
-
Object value;
|
220
|
-
public:
|
221
|
-
OptionalScalarType(Object o) {
|
222
|
-
value = o;
|
223
|
-
}
|
224
|
-
operator c10::optional<at::ScalarType>() {
|
225
|
-
if (value.is_nil()) {
|
226
|
-
return c10::nullopt;
|
227
|
-
}
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
231
|
-
|
232
213
|
template<>
|
233
214
|
inline
|
234
|
-
|
215
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
235
216
|
{
|
236
|
-
|
217
|
+
if (x.is_nil()) {
|
218
|
+
return torch::nullopt;
|
219
|
+
} else {
|
220
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
221
|
+
}
|
237
222
|
}
|
238
223
|
|
239
|
-
typedef torch::Device Device;
|
240
|
-
|
241
224
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
225
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
226
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -460,6 +460,17 @@ module Torch
|
|
460
460
|
zeros(input.size, **like_options(input, options))
|
461
461
|
end
|
462
462
|
|
463
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
464
|
+
if center
|
465
|
+
signal_dim = input.dim
|
466
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
467
|
+
pad = n_fft.div(2).to_i
|
468
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
469
|
+
input = input.view(input.shape[-signal_dim..-1])
|
470
|
+
end
|
471
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
472
|
+
end
|
473
|
+
|
463
474
|
private
|
464
475
|
|
465
476
|
def to_ivalue(obj)
|
@@ -1,10 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
10
|
+
@tensor_options = @function["func"].include?(tensor_options_str)
|
11
|
+
@function["func"].sub!(tensor_options_str, ")")
|
8
12
|
end
|
9
13
|
|
10
14
|
def func
|
@@ -33,30 +33,14 @@ module Torch
|
|
33
33
|
f.args.any? do |a|
|
34
34
|
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
35
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
36
|
+
# call to 'range' is ambiguous
|
37
|
+
f.cpp_name == "_range" ||
|
36
38
|
# native_functions.yaml is missing size argument for normal
|
37
39
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
40
|
(f.base_name == "normal" && !f.out?)
|
39
41
|
end
|
40
42
|
end
|
41
43
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
44
|
# todo_functions.each do |f|
|
61
45
|
# puts f.func
|
62
46
|
# puts
|
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
|
|
97
81
|
|
98
82
|
cpp_defs = []
|
99
83
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
84
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
85
|
+
fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
|
101
86
|
|
102
87
|
cpp_args = []
|
103
88
|
fargs.each do |a|
|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
|
|
109
94
|
# TODO better signature
|
110
95
|
"OptionalTensor"
|
111
96
|
when "ScalarType?"
|
112
|
-
"
|
97
|
+
"torch::optional<ScalarType>"
|
113
98
|
when "Tensor[]"
|
114
99
|
"TensorList"
|
115
100
|
when "Tensor?[]"
|
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
|
|
117
102
|
"TensorList"
|
118
103
|
when "int"
|
119
104
|
"int64_t"
|
105
|
+
when "int?"
|
106
|
+
"torch::optional<int64_t>"
|
120
107
|
when "float"
|
121
108
|
"double"
|
122
109
|
when /\Aint\[/
|
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
|
|
125
112
|
"Tensor &"
|
126
113
|
when "str"
|
127
114
|
"std::string"
|
115
|
+
when "TensorOptions"
|
116
|
+
"const torch::TensorOptions &"
|
128
117
|
else
|
129
118
|
a[:type]
|
130
119
|
end
|
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,8 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
86
88
|
when "float"
|
87
89
|
v.is_a?(Numeric)
|
88
90
|
when /int\[.*\]/
|
@@ -126,9 +128,11 @@ module Torch
|
|
126
128
|
end
|
127
129
|
|
128
130
|
func = candidates.first
|
131
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
132
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
133
|
{
|
130
134
|
name: func.cpp_name,
|
131
|
-
args:
|
135
|
+
args: args
|
132
136
|
}
|
133
137
|
end
|
134
138
|
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -188,49 +188,15 @@ module Torch
|
|
188
188
|
# based on python_variable_indexing.cpp and
|
189
189
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
190
190
|
def [](*indexes)
|
191
|
-
|
192
|
-
dim = 0
|
193
|
-
indexes.each do |index|
|
194
|
-
if index.is_a?(Numeric)
|
195
|
-
result = result._select_int(dim, index)
|
196
|
-
elsif index.is_a?(Range)
|
197
|
-
finish = index.end
|
198
|
-
finish += 1 unless index.exclude_end?
|
199
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
200
|
-
dim += 1
|
201
|
-
elsif index.is_a?(Tensor)
|
202
|
-
result = result.index([index])
|
203
|
-
elsif index.nil?
|
204
|
-
result = result.unsqueeze(dim)
|
205
|
-
dim += 1
|
206
|
-
elsif index == true
|
207
|
-
result = result.unsqueeze(dim)
|
208
|
-
# TODO handle false
|
209
|
-
else
|
210
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
211
|
-
end
|
212
|
-
end
|
213
|
-
result
|
191
|
+
_index(tensor_indexes(indexes))
|
214
192
|
end
|
215
193
|
|
216
194
|
# based on python_variable_indexing.cpp and
|
217
195
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
218
|
-
def []=(
|
196
|
+
def []=(*indexes, value)
|
219
197
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
220
|
-
|
221
198
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
222
|
-
|
223
|
-
if index.is_a?(Numeric)
|
224
|
-
index_put!([Torch.tensor(index)], value)
|
225
|
-
elsif index.is_a?(Range)
|
226
|
-
finish = index.end
|
227
|
-
finish += 1 unless index.exclude_end?
|
228
|
-
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
229
|
-
elsif index.is_a?(Tensor)
|
230
|
-
index_put!([index], value)
|
231
|
-
else
|
232
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
233
|
-
end
|
199
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
234
200
|
end
|
235
201
|
|
236
202
|
# native functions that need manually defined
|
@@ -244,13 +210,13 @@ module Torch
|
|
244
210
|
end
|
245
211
|
end
|
246
212
|
|
247
|
-
#
|
213
|
+
# parser can't handle overlap, so need to handle manually
|
248
214
|
def random!(*args)
|
249
215
|
case args.size
|
250
216
|
when 1
|
251
217
|
_random__to(*args)
|
252
218
|
when 2
|
253
|
-
|
219
|
+
_random__from(*args)
|
254
220
|
else
|
255
221
|
_random_(*args)
|
256
222
|
end
|
@@ -260,5 +226,32 @@ module Torch
|
|
260
226
|
_clamp_min_(min)
|
261
227
|
_clamp_max_(max)
|
262
228
|
end
|
229
|
+
|
230
|
+
private
|
231
|
+
|
232
|
+
def tensor_indexes(indexes)
|
233
|
+
indexes.map do |index|
|
234
|
+
case index
|
235
|
+
when Integer
|
236
|
+
TensorIndex.integer(index)
|
237
|
+
when Range
|
238
|
+
finish = index.end
|
239
|
+
if finish == -1 && !index.exclude_end?
|
240
|
+
finish = nil
|
241
|
+
else
|
242
|
+
finish += 1 unless index.exclude_end?
|
243
|
+
end
|
244
|
+
TensorIndex.slice(index.begin, finish)
|
245
|
+
when Tensor
|
246
|
+
TensorIndex.tensor(index)
|
247
|
+
when nil
|
248
|
+
TensorIndex.none
|
249
|
+
when true, false
|
250
|
+
TensorIndex.boolean(index)
|
251
|
+
else
|
252
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
+
end
|
254
|
+
end
|
255
|
+
end
|
263
256
|
end
|
264
257
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-08-
|
11
|
+
date: 2020-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|