torch-rb 0.2.7 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -2
- data/README.md +9 -2
- data/ext/torch/ext.cpp +49 -7
- data/ext/torch/extconf.rb +3 -4
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +30 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +39 -46
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: fa01dc7cfd168494f1066f4e692eb08dee7f419126b2d10a6eb5b2b22fe01526
|
4
|
+
data.tar.gz: 9e69a7da9ecc85c51bd9cfb42b93f1b1aa6148958ed8bc16ffb0555b6d42159d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 36bc620212235e57a40d9e791b3eae85fbd73c7224a5445e2e94e49c7a3ec4cb638024193d8e47817e577a72ef96976fdf2851f3607fd3ea1712255bca0e1ec1
|
7
|
+
data.tar.gz: 888971d06717b610020644ed720ae48fa645a90790d54692ac64a565434b550ce4f778688f0a22806f018656dd9d6183d6510fc458c8c5e82c0629951489279c
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.4 (2020-08-26)
|
2
|
+
|
3
|
+
- Added `Torch.clamp` method
|
4
|
+
|
5
|
+
## 0.3.3 (2020-08-25)
|
6
|
+
|
7
|
+
- Added spectral ops
|
8
|
+
- Fixed tensor indexing
|
9
|
+
|
10
|
+
## 0.3.2 (2020-08-24)
|
11
|
+
|
12
|
+
- Added `enable_grad` method
|
13
|
+
- Added `random_split` method
|
14
|
+
- Added `collate_fn` option to `DataLoader`
|
15
|
+
- Added `grad=` method to `Tensor`
|
16
|
+
- Fixed error with `grad` method when empty
|
17
|
+
- Fixed `EmbeddingBag`
|
18
|
+
|
19
|
+
## 0.3.1 (2020-08-17)
|
20
|
+
|
21
|
+
- Added `create_graph` and `retain_graph` options to `backward` method
|
22
|
+
- Fixed error when `set` not required
|
23
|
+
|
24
|
+
## 0.3.0 (2020-07-29)
|
25
|
+
|
26
|
+
- Updated LibTorch to 1.6.0
|
27
|
+
- Removed `state_dict` method from optimizers until `load_state_dict` is implemented
|
28
|
+
|
1
29
|
## 0.2.7 (2020-06-29)
|
2
30
|
|
3
31
|
- Made tensors enumerable
|
@@ -43,7 +71,7 @@
|
|
43
71
|
## 0.2.0 (2020-04-22)
|
44
72
|
|
45
73
|
- No longer experimental
|
46
|
-
- Updated
|
74
|
+
- Updated LibTorch to 1.5.0
|
47
75
|
- Added support for GPUs and OpenMP
|
48
76
|
- Added adaptive pooling layers
|
49
77
|
- Tensor `dtype` is now based on Numo type for `Torch.tensor`
|
@@ -52,7 +80,7 @@
|
|
52
80
|
|
53
81
|
## 0.1.8 (2020-01-17)
|
54
82
|
|
55
|
-
- Updated
|
83
|
+
- Updated LibTorch to 1.4.0
|
56
84
|
|
57
85
|
## 0.1.7 (2020-01-10)
|
58
86
|
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -42,6 +46,8 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
42
46
|
- Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
|
43
47
|
- Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
|
44
48
|
|
49
|
+
You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
|
50
|
+
|
45
51
|
## Tutorial
|
46
52
|
|
47
53
|
Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
|
@@ -409,7 +415,8 @@ Here’s the list of compatible versions.
|
|
409
415
|
|
410
416
|
Torch.rb | LibTorch
|
411
417
|
--- | ---
|
412
|
-
0.
|
418
|
+
0.3.0-0.3.4 | 1.6.0
|
419
|
+
0.2.0-0.2.7 | 1.5.0-1.5.1
|
413
420
|
0.1.8 | 1.4.0
|
414
421
|
0.1.0-0.1.7 | 1.3.1
|
415
422
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -48,15 +58,23 @@ void Init_ext()
|
|
48
58
|
.define_singleton_method(
|
49
59
|
"initial_seed",
|
50
60
|
*[]() {
|
51
|
-
return at::detail::getDefaultCPUGenerator()
|
61
|
+
return at::detail::getDefaultCPUGenerator().current_seed();
|
52
62
|
})
|
53
63
|
.define_singleton_method(
|
54
64
|
"seed",
|
55
65
|
*[]() {
|
56
66
|
// TODO set for CUDA when available
|
57
|
-
|
67
|
+
auto generator = at::detail::getDefaultCPUGenerator();
|
68
|
+
return generator.seed();
|
58
69
|
});
|
59
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
60
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
61
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -329,6 +347,18 @@ void Init_ext()
|
|
329
347
|
.define_method("numel", &torch::Tensor::numel)
|
330
348
|
.define_method("element_size", &torch::Tensor::element_size)
|
331
349
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
350
|
+
.define_method(
|
351
|
+
"_index",
|
352
|
+
*[](Tensor& self, Array indices) {
|
353
|
+
auto vec = index_vector(indices);
|
354
|
+
return self.index(vec);
|
355
|
+
})
|
356
|
+
.define_method(
|
357
|
+
"_index_put_custom",
|
358
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
359
|
+
auto vec = index_vector(indices);
|
360
|
+
return self.index_put_(vec, value);
|
361
|
+
})
|
332
362
|
.define_method(
|
333
363
|
"contiguous?",
|
334
364
|
*[](Tensor& self) {
|
@@ -351,13 +381,19 @@ void Init_ext()
|
|
351
381
|
})
|
352
382
|
.define_method(
|
353
383
|
"_backward",
|
354
|
-
*[](Tensor& self,
|
355
|
-
return
|
384
|
+
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
385
|
+
return self.backward(gradient, create_graph, retain_graph);
|
356
386
|
})
|
357
387
|
.define_method(
|
358
388
|
"grad",
|
359
389
|
*[](Tensor& self) {
|
360
|
-
|
390
|
+
auto grad = self.grad();
|
391
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
392
|
+
})
|
393
|
+
.define_method(
|
394
|
+
"grad=",
|
395
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
396
|
+
self.grad() = grad;
|
361
397
|
})
|
362
398
|
.define_method(
|
363
399
|
"_dtype",
|
@@ -460,7 +496,7 @@ void Init_ext()
|
|
460
496
|
.define_singleton_method(
|
461
497
|
"_make_subclass",
|
462
498
|
*[](Tensor& rd, bool requires_grad) {
|
463
|
-
auto data =
|
499
|
+
auto data = rd.detach();
|
464
500
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
465
501
|
auto var = data.set_requires_grad(requires_grad);
|
466
502
|
return Parameter(std::move(var));
|
@@ -501,6 +537,7 @@ void Init_ext()
|
|
501
537
|
});
|
502
538
|
|
503
539
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
540
|
+
.add_handler<torch::Error>(handle_error)
|
504
541
|
.define_singleton_method(
|
505
542
|
"_calculate_gain",
|
506
543
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -579,11 +616,16 @@ void Init_ext()
|
|
579
616
|
*[](Parameter& self) {
|
580
617
|
auto grad = self.grad();
|
581
618
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
619
|
+
})
|
620
|
+
.define_method(
|
621
|
+
"grad=",
|
622
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
623
|
+
self.grad() = grad;
|
582
624
|
});
|
583
625
|
|
584
626
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
585
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
586
627
|
.add_handler<torch::Error>(handle_error)
|
628
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
587
629
|
.define_method("index", &torch::Device::index)
|
588
630
|
.define_method("index?", &torch::Device::has_index)
|
589
631
|
.define_method(
|
data/ext/torch/extconf.rb
CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
8
|
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
|
-
|
11
|
-
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
10
|
+
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
12
11
|
|
13
12
|
# check omp first
|
14
13
|
if have_library("omp") || have_library("gomp")
|
15
14
|
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
-
$CXXFLAGS += " -Xclang" if clang
|
15
|
+
$CXXFLAGS += " -Xclang" if apple_clang
|
17
16
|
$CXXFLAGS += " -fopenmp"
|
18
17
|
end
|
19
18
|
|
20
|
-
if clang
|
19
|
+
if apple_clang
|
21
20
|
# silence ruby/intern.h warning
|
22
21
|
$CXXFLAGS += " -Wno-deprecated-register"
|
23
22
|
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,10 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::ScalarType;
|
14
|
+
using torch::Tensor;
|
15
|
+
|
12
16
|
// need to wrap torch::IntArrayRef() since
|
13
17
|
// it doesn't own underlying data
|
14
18
|
class IntArrayRef {
|
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
178
|
return MyReduction(x);
|
175
179
|
}
|
176
180
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
181
|
class OptionalTensor {
|
180
182
|
Object value;
|
181
183
|
public:
|
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
199
|
return OptionalTensor(x);
|
198
200
|
}
|
199
201
|
|
200
|
-
class ScalarType {
|
201
|
-
Object value;
|
202
|
-
public:
|
203
|
-
ScalarType(Object o) {
|
204
|
-
value = o;
|
205
|
-
}
|
206
|
-
operator at::ScalarType() {
|
207
|
-
throw std::runtime_error("ScalarType arguments not implemented yet");
|
208
|
-
}
|
209
|
-
};
|
210
|
-
|
211
202
|
template<>
|
212
203
|
inline
|
213
|
-
ScalarType from_ruby<ScalarType>(Object x)
|
204
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
214
205
|
{
|
215
|
-
|
206
|
+
if (x.is_nil()) {
|
207
|
+
return torch::nullopt;
|
208
|
+
} else {
|
209
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
210
|
+
}
|
216
211
|
}
|
217
212
|
|
218
|
-
class OptionalScalarType {
|
219
|
-
Object value;
|
220
|
-
public:
|
221
|
-
OptionalScalarType(Object o) {
|
222
|
-
value = o;
|
223
|
-
}
|
224
|
-
operator c10::optional<at::ScalarType>() {
|
225
|
-
if (value.is_nil()) {
|
226
|
-
return c10::nullopt;
|
227
|
-
}
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
231
|
-
|
232
213
|
template<>
|
233
214
|
inline
|
234
|
-
|
215
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
235
216
|
{
|
236
|
-
|
217
|
+
if (x.is_nil()) {
|
218
|
+
return torch::nullopt;
|
219
|
+
} else {
|
220
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
221
|
+
}
|
237
222
|
}
|
238
223
|
|
239
|
-
typedef torch::Device Device;
|
240
|
-
|
241
224
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
225
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
226
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -4,6 +4,7 @@ require "torch/ext"
|
|
4
4
|
# stdlib
|
5
5
|
require "fileutils"
|
6
6
|
require "net/http"
|
7
|
+
require "set"
|
7
8
|
require "tmpdir"
|
8
9
|
|
9
10
|
# native functions
|
@@ -178,8 +179,10 @@ require "torch/nn/functional"
|
|
178
179
|
require "torch/nn/init"
|
179
180
|
|
180
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
181
183
|
require "torch/utils/data/data_loader"
|
182
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
183
186
|
require "torch/utils/data/tensor_dataset"
|
184
187
|
|
185
188
|
# hub
|
@@ -315,6 +318,16 @@ module Torch
|
|
315
318
|
end
|
316
319
|
end
|
317
320
|
|
321
|
+
def enable_grad
|
322
|
+
previous_value = grad_enabled?
|
323
|
+
begin
|
324
|
+
_set_grad_enabled(true)
|
325
|
+
yield
|
326
|
+
ensure
|
327
|
+
_set_grad_enabled(previous_value)
|
328
|
+
end
|
329
|
+
end
|
330
|
+
|
318
331
|
def device(str)
|
319
332
|
Device.new(str)
|
320
333
|
end
|
@@ -447,6 +460,22 @@ module Torch
|
|
447
460
|
zeros(input.size, **like_options(input, options))
|
448
461
|
end
|
449
462
|
|
463
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
464
|
+
if center
|
465
|
+
signal_dim = input.dim
|
466
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
467
|
+
pad = n_fft.div(2).to_i
|
468
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
469
|
+
input = input.view(input.shape[-signal_dim..-1])
|
470
|
+
end
|
471
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
472
|
+
end
|
473
|
+
|
474
|
+
def clamp(tensor, min, max)
|
475
|
+
tensor = _clamp_min(tensor, min)
|
476
|
+
_clamp_max(tensor, max)
|
477
|
+
end
|
478
|
+
|
450
479
|
private
|
451
480
|
|
452
481
|
def to_ivalue(obj)
|
@@ -470,11 +499,7 @@ module Torch
|
|
470
499
|
when nil
|
471
500
|
IValue.new
|
472
501
|
when Array
|
473
|
-
|
474
|
-
IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
|
475
|
-
else
|
476
|
-
raise Error, "Unknown list type"
|
477
|
-
end
|
502
|
+
IValue.from_list(obj.map { |v| to_ivalue(v) })
|
478
503
|
else
|
479
504
|
raise Error, "Unknown type: #{obj.class.name}"
|
480
505
|
end
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
data/lib/torch/native/function.rb
CHANGED
@@ -1,10 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
10
|
+
@tensor_options = @function["func"].include?(tensor_options_str)
|
11
|
+
@function["func"].sub!(tensor_options_str, ")")
|
8
12
|
end
|
9
13
|
|
10
14
|
def func
|
data/lib/torch/native/generator.rb
CHANGED
@@ -33,30 +33,14 @@ module Torch
|
|
33
33
|
f.args.any? do |a|
|
34
34
|
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
35
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
36
|
+
# call to 'range' is ambiguous
|
37
|
+
f.cpp_name == "_range" ||
|
36
38
|
# native_functions.yaml is missing size argument for normal
|
37
39
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
40
|
(f.base_name == "normal" && !f.out?)
|
39
41
|
end
|
40
42
|
end
|
41
43
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
44
|
# todo_functions.each do |f|
|
61
45
|
# puts f.func
|
62
46
|
# puts
|
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
|
|
97
81
|
|
98
82
|
cpp_defs = []
|
99
83
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
84
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
85
|
+
fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
|
101
86
|
|
102
87
|
cpp_args = []
|
103
88
|
fargs.each do |a|
|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
|
|
109
94
|
# TODO better signature
|
110
95
|
"OptionalTensor"
|
111
96
|
when "ScalarType?"
|
112
|
-
"
|
97
|
+
"torch::optional<ScalarType>"
|
113
98
|
when "Tensor[]"
|
114
99
|
"TensorList"
|
115
100
|
when "Tensor?[]"
|
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
|
|
117
102
|
"TensorList"
|
118
103
|
when "int"
|
119
104
|
"int64_t"
|
105
|
+
when "int?"
|
106
|
+
"torch::optional<int64_t>"
|
120
107
|
when "float"
|
121
108
|
"double"
|
122
109
|
when /\Aint\[/
|
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
|
|
125
112
|
"Tensor &"
|
126
113
|
when "str"
|
127
114
|
"std::string"
|
115
|
+
when "TensorOptions"
|
116
|
+
"const torch::TensorOptions &"
|
128
117
|
else
|
129
118
|
a[:type]
|
130
119
|
end
|