torch-rb 0.2.7 → 0.3.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -2
- data/README.md +9 -2
- data/ext/torch/ext.cpp +49 -7
- data/ext/torch/extconf.rb +3 -4
- data/ext/torch/templates.hpp +16 -33
- data/lib/torch.rb +30 -5
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/native/function.rb +5 -1
- data/lib/torch/native/generator.rb +9 -20
- data/lib/torch/native/native_functions.yaml +654 -660
- data/lib/torch/native/parser.rb +5 -1
- data/lib/torch/nn/conv2d.rb +0 -1
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/optim/optimizer.rb +6 -4
- data/lib/torch/tensor.rb +39 -46
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: fa01dc7cfd168494f1066f4e692eb08dee7f419126b2d10a6eb5b2b22fe01526
|
4
|
+
data.tar.gz: 9e69a7da9ecc85c51bd9cfb42b93f1b1aa6148958ed8bc16ffb0555b6d42159d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 36bc620212235e57a40d9e791b3eae85fbd73c7224a5445e2e94e49c7a3ec4cb638024193d8e47817e577a72ef96976fdf2851f3607fd3ea1712255bca0e1ec1
|
7
|
+
data.tar.gz: 888971d06717b610020644ed720ae48fa645a90790d54692ac64a565434b550ce4f778688f0a22806f018656dd9d6183d6510fc458c8c5e82c0629951489279c
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.4 (2020-08-26)
|
2
|
+
|
3
|
+
- Added `Torch.clamp` method
|
4
|
+
|
5
|
+
## 0.3.3 (2020-08-25)
|
6
|
+
|
7
|
+
- Added spectral ops
|
8
|
+
- Fixed tensor indexing
|
9
|
+
|
10
|
+
## 0.3.2 (2020-08-24)
|
11
|
+
|
12
|
+
- Added `enable_grad` method
|
13
|
+
- Added `random_split` method
|
14
|
+
- Added `collate_fn` option to `DataLoader`
|
15
|
+
- Added `grad=` method to `Tensor`
|
16
|
+
- Fixed error with `grad` method when empty
|
17
|
+
- Fixed `EmbeddingBag`
|
18
|
+
|
19
|
+
## 0.3.1 (2020-08-17)
|
20
|
+
|
21
|
+
- Added `create_graph` and `retain_graph` options to `backward` method
|
22
|
+
- Fixed error when `set` not required
|
23
|
+
|
24
|
+
## 0.3.0 (2020-07-29)
|
25
|
+
|
26
|
+
- Updated LibTorch to 1.6.0
|
27
|
+
- Removed `state_dict` method from optimizers until `load_state_dict` is implemented
|
28
|
+
|
1
29
|
## 0.2.7 (2020-06-29)
|
2
30
|
|
3
31
|
- Made tensors enumerable
|
@@ -43,7 +71,7 @@
|
|
43
71
|
## 0.2.0 (2020-04-22)
|
44
72
|
|
45
73
|
- No longer experimental
|
46
|
-
- Updated
|
74
|
+
- Updated LibTorch to 1.5.0
|
47
75
|
- Added support for GPUs and OpenMP
|
48
76
|
- Added adaptive pooling layers
|
49
77
|
- Tensor `dtype` is now based on Numo type for `Torch.tensor`
|
@@ -52,7 +80,7 @@
|
|
52
80
|
|
53
81
|
## 0.1.8 (2020-01-17)
|
54
82
|
|
55
|
-
- Updated
|
83
|
+
- Updated LibTorch to 1.4.0
|
56
84
|
|
57
85
|
## 0.1.7 (2020-01-10)
|
58
86
|
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -42,6 +46,8 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
42
46
|
- Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
|
43
47
|
- Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
|
44
48
|
|
49
|
+
You can follow PyTorch tutorials and convert the code to Ruby in many cases. Feel free to open an issue if you run into problems.
|
50
|
+
|
45
51
|
## Tutorial
|
46
52
|
|
47
53
|
Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
|
@@ -409,7 +415,8 @@ Here’s the list of compatible versions.
|
|
409
415
|
|
410
416
|
Torch.rb | LibTorch
|
411
417
|
--- | ---
|
412
|
-
0.
|
418
|
+
0.3.0-0.3.4 | 1.6.0
|
419
|
+
0.2.0-0.2.7 | 1.5.0-1.5.1
|
413
420
|
0.1.8 | 1.4.0
|
414
421
|
0.1.0-0.1.7 | 1.3.1
|
415
422
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -48,15 +58,23 @@ void Init_ext()
|
|
48
58
|
.define_singleton_method(
|
49
59
|
"initial_seed",
|
50
60
|
*[]() {
|
51
|
-
return at::detail::getDefaultCPUGenerator()
|
61
|
+
return at::detail::getDefaultCPUGenerator().current_seed();
|
52
62
|
})
|
53
63
|
.define_singleton_method(
|
54
64
|
"seed",
|
55
65
|
*[]() {
|
56
66
|
// TODO set for CUDA when available
|
57
|
-
|
67
|
+
auto generator = at::detail::getDefaultCPUGenerator();
|
68
|
+
return generator.seed();
|
58
69
|
});
|
59
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
60
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
61
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -329,6 +347,18 @@ void Init_ext()
|
|
329
347
|
.define_method("numel", &torch::Tensor::numel)
|
330
348
|
.define_method("element_size", &torch::Tensor::element_size)
|
331
349
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
350
|
+
.define_method(
|
351
|
+
"_index",
|
352
|
+
*[](Tensor& self, Array indices) {
|
353
|
+
auto vec = index_vector(indices);
|
354
|
+
return self.index(vec);
|
355
|
+
})
|
356
|
+
.define_method(
|
357
|
+
"_index_put_custom",
|
358
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
359
|
+
auto vec = index_vector(indices);
|
360
|
+
return self.index_put_(vec, value);
|
361
|
+
})
|
332
362
|
.define_method(
|
333
363
|
"contiguous?",
|
334
364
|
*[](Tensor& self) {
|
@@ -351,13 +381,19 @@ void Init_ext()
|
|
351
381
|
})
|
352
382
|
.define_method(
|
353
383
|
"_backward",
|
354
|
-
*[](Tensor& self,
|
355
|
-
return
|
384
|
+
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
385
|
+
return self.backward(gradient, create_graph, retain_graph);
|
356
386
|
})
|
357
387
|
.define_method(
|
358
388
|
"grad",
|
359
389
|
*[](Tensor& self) {
|
360
|
-
|
390
|
+
auto grad = self.grad();
|
391
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
392
|
+
})
|
393
|
+
.define_method(
|
394
|
+
"grad=",
|
395
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
396
|
+
self.grad() = grad;
|
361
397
|
})
|
362
398
|
.define_method(
|
363
399
|
"_dtype",
|
@@ -460,7 +496,7 @@ void Init_ext()
|
|
460
496
|
.define_singleton_method(
|
461
497
|
"_make_subclass",
|
462
498
|
*[](Tensor& rd, bool requires_grad) {
|
463
|
-
auto data =
|
499
|
+
auto data = rd.detach();
|
464
500
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
465
501
|
auto var = data.set_requires_grad(requires_grad);
|
466
502
|
return Parameter(std::move(var));
|
@@ -501,6 +537,7 @@ void Init_ext()
|
|
501
537
|
});
|
502
538
|
|
503
539
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
540
|
+
.add_handler<torch::Error>(handle_error)
|
504
541
|
.define_singleton_method(
|
505
542
|
"_calculate_gain",
|
506
543
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -579,11 +616,16 @@ void Init_ext()
|
|
579
616
|
*[](Parameter& self) {
|
580
617
|
auto grad = self.grad();
|
581
618
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
619
|
+
})
|
620
|
+
.define_method(
|
621
|
+
"grad=",
|
622
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
623
|
+
self.grad() = grad;
|
582
624
|
});
|
583
625
|
|
584
626
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
585
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
586
627
|
.add_handler<torch::Error>(handle_error)
|
628
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
587
629
|
.define_method("index", &torch::Device::index)
|
588
630
|
.define_method("index?", &torch::Device::has_index)
|
589
631
|
.define_method(
|
data/ext/torch/extconf.rb
CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
8
|
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
|
-
|
11
|
-
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
10
|
+
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
12
11
|
|
13
12
|
# check omp first
|
14
13
|
if have_library("omp") || have_library("gomp")
|
15
14
|
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
-
$CXXFLAGS += " -Xclang" if
|
15
|
+
$CXXFLAGS += " -Xclang" if apple_clang
|
17
16
|
$CXXFLAGS += " -fopenmp"
|
18
17
|
end
|
19
18
|
|
20
|
-
if
|
19
|
+
if apple_clang
|
21
20
|
# silence ruby/intern.h warning
|
22
21
|
$CXXFLAGS += " -Wno-deprecated-register"
|
23
22
|
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,10 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::ScalarType;
|
14
|
+
using torch::Tensor;
|
15
|
+
|
12
16
|
// need to wrap torch::IntArrayRef() since
|
13
17
|
// it doesn't own underlying data
|
14
18
|
class IntArrayRef {
|
@@ -174,8 +178,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
178
|
return MyReduction(x);
|
175
179
|
}
|
176
180
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
181
|
class OptionalTensor {
|
180
182
|
Object value;
|
181
183
|
public:
|
@@ -197,47 +199,28 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
199
|
return OptionalTensor(x);
|
198
200
|
}
|
199
201
|
|
200
|
-
class ScalarType {
|
201
|
-
Object value;
|
202
|
-
public:
|
203
|
-
ScalarType(Object o) {
|
204
|
-
value = o;
|
205
|
-
}
|
206
|
-
operator at::ScalarType() {
|
207
|
-
throw std::runtime_error("ScalarType arguments not implemented yet");
|
208
|
-
}
|
209
|
-
};
|
210
|
-
|
211
202
|
template<>
|
212
203
|
inline
|
213
|
-
ScalarType from_ruby<ScalarType
|
204
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
214
205
|
{
|
215
|
-
|
206
|
+
if (x.is_nil()) {
|
207
|
+
return torch::nullopt;
|
208
|
+
} else {
|
209
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
210
|
+
}
|
216
211
|
}
|
217
212
|
|
218
|
-
class OptionalScalarType {
|
219
|
-
Object value;
|
220
|
-
public:
|
221
|
-
OptionalScalarType(Object o) {
|
222
|
-
value = o;
|
223
|
-
}
|
224
|
-
operator c10::optional<at::ScalarType>() {
|
225
|
-
if (value.is_nil()) {
|
226
|
-
return c10::nullopt;
|
227
|
-
}
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
231
|
-
|
232
213
|
template<>
|
233
214
|
inline
|
234
|
-
|
215
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
235
216
|
{
|
236
|
-
|
217
|
+
if (x.is_nil()) {
|
218
|
+
return torch::nullopt;
|
219
|
+
} else {
|
220
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
221
|
+
}
|
237
222
|
}
|
238
223
|
|
239
|
-
typedef torch::Device Device;
|
240
|
-
|
241
224
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
225
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
226
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -4,6 +4,7 @@ require "torch/ext"
|
|
4
4
|
# stdlib
|
5
5
|
require "fileutils"
|
6
6
|
require "net/http"
|
7
|
+
require "set"
|
7
8
|
require "tmpdir"
|
8
9
|
|
9
10
|
# native functions
|
@@ -178,8 +179,10 @@ require "torch/nn/functional"
|
|
178
179
|
require "torch/nn/init"
|
179
180
|
|
180
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
181
183
|
require "torch/utils/data/data_loader"
|
182
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
183
186
|
require "torch/utils/data/tensor_dataset"
|
184
187
|
|
185
188
|
# hub
|
@@ -315,6 +318,16 @@ module Torch
|
|
315
318
|
end
|
316
319
|
end
|
317
320
|
|
321
|
+
def enable_grad
|
322
|
+
previous_value = grad_enabled?
|
323
|
+
begin
|
324
|
+
_set_grad_enabled(true)
|
325
|
+
yield
|
326
|
+
ensure
|
327
|
+
_set_grad_enabled(previous_value)
|
328
|
+
end
|
329
|
+
end
|
330
|
+
|
318
331
|
def device(str)
|
319
332
|
Device.new(str)
|
320
333
|
end
|
@@ -447,6 +460,22 @@ module Torch
|
|
447
460
|
zeros(input.size, **like_options(input, options))
|
448
461
|
end
|
449
462
|
|
463
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
464
|
+
if center
|
465
|
+
signal_dim = input.dim
|
466
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
467
|
+
pad = n_fft.div(2).to_i
|
468
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
469
|
+
input = input.view(input.shape[-signal_dim..-1])
|
470
|
+
end
|
471
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
472
|
+
end
|
473
|
+
|
474
|
+
def clamp(tensor, min, max)
|
475
|
+
tensor = _clamp_min(tensor, min)
|
476
|
+
_clamp_max(tensor, max)
|
477
|
+
end
|
478
|
+
|
450
479
|
private
|
451
480
|
|
452
481
|
def to_ivalue(obj)
|
@@ -470,11 +499,7 @@ module Torch
|
|
470
499
|
when nil
|
471
500
|
IValue.new
|
472
501
|
when Array
|
473
|
-
|
474
|
-
IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
|
475
|
-
else
|
476
|
-
raise Error, "Unknown list type"
|
477
|
-
end
|
502
|
+
IValue.from_list(obj.map { |v| to_ivalue(v) })
|
478
503
|
else
|
479
504
|
raise Error, "Unknown type: #{obj.class.name}"
|
480
505
|
end
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
@@ -1,10 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
10
|
+
@tensor_options = @function["func"].include?(tensor_options_str)
|
11
|
+
@function["func"].sub!(tensor_options_str, ")")
|
8
12
|
end
|
9
13
|
|
10
14
|
def func
|
@@ -33,30 +33,14 @@ module Torch
|
|
33
33
|
f.args.any? do |a|
|
34
34
|
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
35
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
36
|
+
# call to 'range' is ambiguous
|
37
|
+
f.cpp_name == "_range" ||
|
36
38
|
# native_functions.yaml is missing size argument for normal
|
37
39
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
40
|
(f.base_name == "normal" && !f.out?)
|
39
41
|
end
|
40
42
|
end
|
41
43
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
44
|
# todo_functions.each do |f|
|
61
45
|
# puts f.func
|
62
46
|
# puts
|
@@ -97,7 +81,8 @@ void add_%{type}_functions(Module m) {
|
|
97
81
|
|
98
82
|
cpp_defs = []
|
99
83
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
84
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
85
|
+
fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
|
101
86
|
|
102
87
|
cpp_args = []
|
103
88
|
fargs.each do |a|
|
@@ -109,7 +94,7 @@ void add_%{type}_functions(Module m) {
|
|
109
94
|
# TODO better signature
|
110
95
|
"OptionalTensor"
|
111
96
|
when "ScalarType?"
|
112
|
-
"
|
97
|
+
"torch::optional<ScalarType>"
|
113
98
|
when "Tensor[]"
|
114
99
|
"TensorList"
|
115
100
|
when "Tensor?[]"
|
@@ -117,6 +102,8 @@ void add_%{type}_functions(Module m) {
|
|
117
102
|
"TensorList"
|
118
103
|
when "int"
|
119
104
|
"int64_t"
|
105
|
+
when "int?"
|
106
|
+
"torch::optional<int64_t>"
|
120
107
|
when "float"
|
121
108
|
"double"
|
122
109
|
when /\Aint\[/
|
@@ -125,6 +112,8 @@ void add_%{type}_functions(Module m) {
|
|
125
112
|
"Tensor &"
|
126
113
|
when "str"
|
127
114
|
"std::string"
|
115
|
+
when "TensorOptions"
|
116
|
+
"const torch::TensorOptions &"
|
128
117
|
else
|
129
118
|
a[:type]
|
130
119
|
end
|