torch-rb 0.3.0 → 0.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +27 -0
- data/README.md +6 -2
- data/ext/torch/ext.cpp +45 -4
- data/ext/torch/extconf.rb +3 -4
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +33 -0
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/tensor.rb +35 -41
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 93271ffd62be6e35c6ea3a2219a7bc3dccbe8489d6f4aca1a1f00f99bab1a4bb
|
4
|
+
data.tar.gz: df2755ac3e6221502430d780d116a1145c461f97857e7b1b2b809095afaad9e5
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9e02de90a7a83e5d4421941a0ceea69c6367f42be2e97e4229812f35c27f83475fbace86f42285128da724343dcfd85050b7846d81d43fce100749be0072ad4c
|
7
|
+
data.tar.gz: 884873c3c965f16b0a833087909019ebc6f228511a9b0ecbf4c436cc546e28012e63899651aeab6befddbe7b057676b170cd3eb858b997f42289c103252a2834
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,30 @@
|
|
1
|
+
## 0.3.5 (2020-09-04)
|
2
|
+
|
3
|
+
- Fixed error with data loader (due to `dtype` of `randperm`)
|
4
|
+
|
5
|
+
## 0.3.4 (2020-08-26)
|
6
|
+
|
7
|
+
- Added `Torch.clamp` method
|
8
|
+
|
9
|
+
## 0.3.3 (2020-08-25)
|
10
|
+
|
11
|
+
- Added spectral ops
|
12
|
+
- Fixed tensor indexing
|
13
|
+
|
14
|
+
## 0.3.2 (2020-08-24)
|
15
|
+
|
16
|
+
- Added `enable_grad` method
|
17
|
+
- Added `random_split` method
|
18
|
+
- Added `collate_fn` option to `DataLoader`
|
19
|
+
- Added `grad=` method to `Tensor`
|
20
|
+
- Fixed error with `grad` method when empty
|
21
|
+
- Fixed `EmbeddingBag`
|
22
|
+
|
23
|
+
## 0.3.1 (2020-08-17)
|
24
|
+
|
25
|
+
- Added `create_graph` and `retain_graph` options to `backward` method
|
26
|
+
- Fixed error when `set` not required
|
27
|
+
|
1
28
|
## 0.3.0 (2020-07-29)
|
2
29
|
|
3
30
|
- Updated LibTorch to 1.6.0
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[Build Status](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -411,7 +415,7 @@ Here’s the list of compatible versions.
|
|
411
415
|
|
412
416
|
Torch.rb | LibTorch
|
413
417
|
--- | ---
|
414
|
-
0.3.0 | 1.6.0
|
418
|
+
0.3.0-0.3.4 | 1.6.0
|
415
419
|
0.2.0-0.2.7 | 1.5.0-1.5.1
|
416
420
|
0.1.8 | 1.4.0
|
417
421
|
0.1.0-0.1.7 | 1.3.1
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -58,6 +68,13 @@ void Init_ext()
|
|
58
68
|
return generator.seed();
|
59
69
|
});
|
60
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
61
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
62
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
63
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -330,6 +347,18 @@ void Init_ext()
|
|
330
347
|
.define_method("numel", &torch::Tensor::numel)
|
331
348
|
.define_method("element_size", &torch::Tensor::element_size)
|
332
349
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
350
|
+
.define_method(
|
351
|
+
"_index",
|
352
|
+
*[](Tensor& self, Array indices) {
|
353
|
+
auto vec = index_vector(indices);
|
354
|
+
return self.index(vec);
|
355
|
+
})
|
356
|
+
.define_method(
|
357
|
+
"_index_put_custom",
|
358
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
359
|
+
auto vec = index_vector(indices);
|
360
|
+
return self.index_put_(vec, value);
|
361
|
+
})
|
333
362
|
.define_method(
|
334
363
|
"contiguous?",
|
335
364
|
*[](Tensor& self) {
|
@@ -352,13 +381,19 @@ void Init_ext()
|
|
352
381
|
})
|
353
382
|
.define_method(
|
354
383
|
"_backward",
|
355
|
-
*[](Tensor& self,
|
356
|
-
return
|
384
|
+
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
385
|
+
return self.backward(gradient, create_graph, retain_graph);
|
357
386
|
})
|
358
387
|
.define_method(
|
359
388
|
"grad",
|
360
389
|
*[](Tensor& self) {
|
361
|
-
|
390
|
+
auto grad = self.grad();
|
391
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
392
|
+
})
|
393
|
+
.define_method(
|
394
|
+
"grad=",
|
395
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
396
|
+
self.grad() = grad;
|
362
397
|
})
|
363
398
|
.define_method(
|
364
399
|
"_dtype",
|
@@ -502,6 +537,7 @@ void Init_ext()
|
|
502
537
|
});
|
503
538
|
|
504
539
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
540
|
+
.add_handler<torch::Error>(handle_error)
|
505
541
|
.define_singleton_method(
|
506
542
|
"_calculate_gain",
|
507
543
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -580,11 +616,16 @@ void Init_ext()
|
|
580
616
|
*[](Parameter& self) {
|
581
617
|
auto grad = self.grad();
|
582
618
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
619
|
+
})
|
620
|
+
.define_method(
|
621
|
+
"grad=",
|
622
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
623
|
+
self.grad() = grad;
|
583
624
|
});
|
584
625
|
|
585
626
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
586
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
587
627
|
.add_handler<torch::Error>(handle_error)
|
628
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
588
629
|
.define_method("index", &torch::Device::index)
|
589
630
|
.define_method("index?", &torch::Device::has_index)
|
590
631
|
.define_method(
|
data/ext/torch/extconf.rb
CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
8
|
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
|
-
|
11
|
-
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
10
|
+
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
12
11
|
|
13
12
|
# check omp first
|
14
13
|
if have_library("omp") || have_library("gomp")
|
15
14
|
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
-
$CXXFLAGS += " -Xclang" if clang
|
15
|
+
$CXXFLAGS += " -Xclang" if apple_clang
|
17
16
|
$CXXFLAGS += " -fopenmp"
|
18
17
|
end
|
19
18
|
|
20
|
-
if clang
|
19
|
+
if apple_clang
|
21
20
|
# silence ruby/intern.h warning
|
22
21
|
$CXXFLAGS += " -Wno-deprecated-register"
|
23
22
|
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,10 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::ScalarType;
|
14
|
+
using torch::Tensor;
|
15
|
+
|
12
16
|
// need to wrap torch::IntArrayRef() since
|
13
17
|
// it doesn't own underlying data
|
14
18
|
class IntArrayRef {
|
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
178
|
return MyReduction(x);
|
175
179
|
}
|
176
180
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
181
|
class OptionalTensor {
|
180
182
|
Object value;
|
181
183
|
public:
|
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
199
|
return OptionalTensor(x);
|
198
200
|
}
|
199
201
|
|
200
|
-
class ScalarType {
|
201
|
-
Object value;
|
202
|
-
public:
|
203
|
-
ScalarType(Object o) {
|
204
|
-
value = o;
|
205
|
-
}
|
206
|
-
operator at::ScalarType() {
|
207
|
-
throw std::runtime_error("ScalarType arguments not implemented yet");
|
208
|
-
}
|
209
|
-
};
|
210
|
-
|
211
202
|
template<>
|
212
203
|
inline
|
213
|
-
ScalarType from_ruby<ScalarType>(Object x)
|
204
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
214
205
|
{
|
215
|
-
|
206
|
+
if (x.is_nil()) {
|
207
|
+
return torch::nullopt;
|
208
|
+
} else {
|
209
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
210
|
+
}
|
216
211
|
}
|
217
212
|
|
218
|
-
class OptionalScalarType {
|
219
|
-
Object value;
|
220
|
-
public:
|
221
|
-
OptionalScalarType(Object o) {
|
222
|
-
value = o;
|
223
|
-
}
|
224
|
-
operator c10::optional<at::ScalarType>() {
|
225
|
-
if (value.is_nil()) {
|
226
|
-
return c10::nullopt;
|
227
|
-
}
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
231
|
-
|
232
213
|
template<>
|
233
214
|
inline
|
234
|
-
|
215
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
235
216
|
{
|
236
|
-
|
217
|
+
if (x.is_nil()) {
|
218
|
+
return torch::nullopt;
|
219
|
+
} else {
|
220
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
221
|
+
}
|
237
222
|
}
|
238
223
|
|
239
|
-
typedef torch::Device Device;
|
240
|
-
|
241
224
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
225
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
226
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -4,6 +4,7 @@ require "torch/ext"
|
|
4
4
|
# stdlib
|
5
5
|
require "fileutils"
|
6
6
|
require "net/http"
|
7
|
+
require "set"
|
7
8
|
require "tmpdir"
|
8
9
|
|
9
10
|
# native functions
|
@@ -178,8 +179,10 @@ require "torch/nn/functional"
|
|
178
179
|
require "torch/nn/init"
|
179
180
|
|
180
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
181
183
|
require "torch/utils/data/data_loader"
|
182
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
183
186
|
require "torch/utils/data/tensor_dataset"
|
184
187
|
|
185
188
|
# hub
|
@@ -315,6 +318,16 @@ module Torch
|
|
315
318
|
end
|
316
319
|
end
|
317
320
|
|
321
|
+
def enable_grad
|
322
|
+
previous_value = grad_enabled?
|
323
|
+
begin
|
324
|
+
_set_grad_enabled(true)
|
325
|
+
yield
|
326
|
+
ensure
|
327
|
+
_set_grad_enabled(previous_value)
|
328
|
+
end
|
329
|
+
end
|
330
|
+
|
318
331
|
def device(str)
|
319
332
|
Device.new(str)
|
320
333
|
end
|
@@ -375,6 +388,10 @@ module Torch
|
|
375
388
|
end
|
376
389
|
|
377
390
|
def randperm(n, **options)
|
391
|
+
# dtype hack in Python
|
392
|
+
# https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
|
393
|
+
options[:dtype] ||= :int64
|
394
|
+
|
378
395
|
_randperm(n, tensor_options(**options))
|
379
396
|
end
|
380
397
|
|
@@ -447,6 +464,22 @@ module Torch
|
|
447
464
|
zeros(input.size, **like_options(input, options))
|
448
465
|
end
|
449
466
|
|
467
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
468
|
+
if center
|
469
|
+
signal_dim = input.dim
|
470
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
471
|
+
pad = n_fft.div(2).to_i
|
472
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
473
|
+
input = input.view(input.shape[-signal_dim..-1])
|
474
|
+
end
|
475
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
476
|
+
end
|
477
|
+
|
478
|
+
def clamp(tensor, min, max)
|
479
|
+
tensor = _clamp_min(tensor, min)
|
480
|
+
_clamp_max(tensor, max)
|
481
|
+
end
|
482
|
+
|
450
483
|
private
|
451
484
|
|
452
485
|
def to_ivalue(obj)
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
data/lib/torch/native/function.rb
CHANGED
@@ -1,10 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
10
|
+
@tensor_options = @function["func"].include?(tensor_options_str)
|
11
|
+
@function["func"].sub!(tensor_options_str, ")")
|
8
12
|
end
|
9
13
|
|
10
14
|
def func
|
data/lib/torch/native/generator.rb
CHANGED
@@ -33,30 +33,14 @@ module Torch
|
|
33
33
|
f.args.any? do |a|
|
34
34
|
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
35
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
36
|
+
# call to 'range' is ambiguous
|
37
|
+
f.cpp_name == "_range" ||
|
36
38
|
# native_functions.yaml is missing size argument for normal
|
37
39
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
40
|
(f.base_name == "normal" && !f.out?)
|
39
41
|
end
|
40
42
|
end
|
41
43
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
44
|
# todo_functions.each do |f|
|
61
45
|
# puts f.func
|
62
46
|
# puts
|
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
|
|
97
81
|
|
98
82
|
cpp_defs = []
|
99
83
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
84
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
85
|
+
fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
|
101
86
|
|
102
87
|
cpp_args = []
|
103
88
|
fargs.each do |a|
|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
|
|
109
94
|
# TODO better signature
|
110
95
|
"OptionalTensor"
|
111
96
|
when "ScalarType?"
|
112
|
-
"
|
97
|
+
"torch::optional<ScalarType>"
|
113
98
|
when "Tensor[]"
|
114
99
|
"TensorList"
|
115
100
|
when "Tensor?[]"
|
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
|
|
117
102
|
"TensorList"
|
118
103
|
when "int"
|
119
104
|
"int64_t"
|
105
|
+
when "int?"
|
106
|
+
"torch::optional<int64_t>"
|
120
107
|
when "float"
|
121
108
|
"double"
|
122
109
|
when /\Aint\[/
|
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
|
|
125
112
|
"Tensor &"
|
126
113
|
when "str"
|
127
114
|
"std::string"
|
115
|
+
when "TensorOptions"
|
116
|
+
"const torch::TensorOptions &"
|
128
117
|
else
|
129
118
|
a[:type]
|
130
119
|
end
|
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,8 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
86
88
|
when "float"
|
87
89
|
v.is_a?(Numeric)
|
88
90
|
when /int\[.*\]/
|
@@ -126,9 +128,11 @@ module Torch
|
|
126
128
|
end
|
127
129
|
|
128
130
|
func = candidates.first
|
131
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
132
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
133
|
{
|
130
134
|
name: func.cpp_name,
|
131
|
-
args:
|
135
|
+
args: args
|
132
136
|
}
|
133
137
|
end
|
134
138
|
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -373,7 +373,8 @@ module Torch
|
|
373
373
|
end
|
374
374
|
|
375
375
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
376
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
377
|
+
ret
|
377
378
|
end
|
378
379
|
|
379
380
|
# distance functions
|
@@ -426,6 +427,9 @@ module Torch
|
|
426
427
|
end
|
427
428
|
|
428
429
|
def mse_loss(input, target, reduction: "mean")
|
430
|
+
if target.size != input.size
|
431
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
|
+
end
|
429
433
|
NN.mse_loss(input, target, reduction)
|
430
434
|
end
|
431
435
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -103,8 +103,9 @@ module Torch
|
|
103
103
|
Torch.empty(0, dtype: dtype)
|
104
104
|
end
|
105
105
|
|
106
|
-
def backward(gradient = nil)
|
107
|
-
|
106
|
+
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
+
retain_graph = create_graph if retain_graph.nil?
|
108
|
+
_backward(gradient, retain_graph, create_graph)
|
108
109
|
end
|
109
110
|
|
110
111
|
# TODO read directly from memory
|
@@ -187,49 +188,15 @@ module Torch
|
|
187
188
|
# based on python_variable_indexing.cpp and
|
188
189
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
189
190
|
def [](*indexes)
|
190
|
-
|
191
|
-
dim = 0
|
192
|
-
indexes.each do |index|
|
193
|
-
if index.is_a?(Numeric)
|
194
|
-
result = result._select_int(dim, index)
|
195
|
-
elsif index.is_a?(Range)
|
196
|
-
finish = index.end
|
197
|
-
finish += 1 unless index.exclude_end?
|
198
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
199
|
-
dim += 1
|
200
|
-
elsif index.is_a?(Tensor)
|
201
|
-
result = result.index([index])
|
202
|
-
elsif index.nil?
|
203
|
-
result = result.unsqueeze(dim)
|
204
|
-
dim += 1
|
205
|
-
elsif index == true
|
206
|
-
result = result.unsqueeze(dim)
|
207
|
-
# TODO handle false
|
208
|
-
else
|
209
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
210
|
-
end
|
211
|
-
end
|
212
|
-
result
|
191
|
+
_index(tensor_indexes(indexes))
|
213
192
|
end
|
214
193
|
|
215
194
|
# based on python_variable_indexing.cpp and
|
216
195
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
217
|
-
def []=(index, value)
|
196
|
+
def []=(*indexes, value)
|
218
197
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
219
|
-
|
220
198
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
221
|
-
|
222
|
-
if index.is_a?(Numeric)
|
223
|
-
index_put!([Torch.tensor(index)], value)
|
224
|
-
elsif index.is_a?(Range)
|
225
|
-
finish = index.end
|
226
|
-
finish += 1 unless index.exclude_end?
|
227
|
-
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
228
|
-
elsif index.is_a?(Tensor)
|
229
|
-
index_put!([index], value)
|
230
|
-
else
|
231
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
232
|
-
end
|
199
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
233
200
|
end
|
234
201
|
|
235
202
|
# native functions that need manually defined
|
@@ -243,13 +210,13 @@ module Torch
|
|
243
210
|
end
|
244
211
|
end
|
245
212
|
|
246
|
-
#
|
213
|
+
# parser can't handle overlap, so need to handle manually
|
247
214
|
def random!(*args)
|
248
215
|
case args.size
|
249
216
|
when 1
|
250
217
|
_random__to(*args)
|
251
218
|
when 2
|
252
|
-
|
219
|
+
_random__from(*args)
|
253
220
|
else
|
254
221
|
_random_(*args)
|
255
222
|
end
|
@@ -259,5 +226,32 @@ module Torch
|
|
259
226
|
_clamp_min_(min)
|
260
227
|
_clamp_max_(max)
|
261
228
|
end
|
229
|
+
|
230
|
+
private
|
231
|
+
|
232
|
+
def tensor_indexes(indexes)
|
233
|
+
indexes.map do |index|
|
234
|
+
case index
|
235
|
+
when Integer
|
236
|
+
TensorIndex.integer(index)
|
237
|
+
when Range
|
238
|
+
finish = index.end
|
239
|
+
if finish == -1 && !index.exclude_end?
|
240
|
+
finish = nil
|
241
|
+
else
|
242
|
+
finish += 1 unless index.exclude_end?
|
243
|
+
end
|
244
|
+
TensorIndex.slice(index.begin, finish)
|
245
|
+
when Tensor
|
246
|
+
TensorIndex.tensor(index)
|
247
|
+
when nil
|
248
|
+
TensorIndex.none
|
249
|
+
when true, false
|
250
|
+
TensorIndex.boolean(index)
|
251
|
+
else
|
252
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
+
end
|
254
|
+
end
|
255
|
+
end
|
262
256
|
end
|
263
257
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,8 +37,8 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
@@ -36,7 +48,7 @@ module Torch
|
|
36
48
|
|
37
49
|
private
|
38
50
|
|
39
|
-
def
|
51
|
+
def default_convert(batch)
|
40
52
|
elem = batch[0]
|
41
53
|
case elem
|
42
54
|
when Tensor
|
@@ -44,11 +56,15 @@ module Torch
|
|
44
56
|
when Integer
|
45
57
|
Torch.tensor(batch)
|
46
58
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
59
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
60
|
else
|
49
|
-
raise
|
61
|
+
raise NotImplementedYet
|
50
62
|
end
|
51
63
|
end
|
64
|
+
|
65
|
+
def auto_collation?
|
66
|
+
!@batch_sampler.nil?
|
67
|
+
end
|
52
68
|
end
|
53
69
|
end
|
54
70
|
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.0
|
4
|
+
version: 0.3.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-09-04 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|