torch-rb 0.9.0 → 0.10.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/README.md +6 -4
- data/codegen/function.rb +2 -2
- data/codegen/generate_functions.rb +5 -1
- data/codegen/native_functions.yaml +951 -362
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +6 -6
- data/ext/torch/sparse_functions.h +6 -0
- data/ext/torch/templates.h +34 -0
- data/ext/torch/tensor.cpp +25 -25
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/nn/parameter.rb +3 -0
- data/lib/torch/nn/parameter_list.rb +48 -0
- data/lib/torch/tensor.rb +4 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +15 -48
- metadata +5 -3
data/ext/torch/backends.cpp
CHANGED
@@ -7,11 +7,11 @@
|
|
7
7
|
void init_backends(Rice::Module& m) {
|
8
8
|
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
9
|
|
10
|
-
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
11
|
.add_handler<torch::Error>(handle_error)
|
12
12
|
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
13
|
|
14
|
-
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
15
|
.add_handler<torch::Error>(handle_error)
|
16
16
|
.define_singleton_function("available?", &torch::hasMKL);
|
17
17
|
}
|
data/ext/torch/nn.cpp
CHANGED
@@ -98,8 +98,11 @@ void init_nn(Rice::Module& m) {
|
|
98
98
|
auto grad = self.grad();
|
99
99
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
|
+
// can't use grad=
|
102
|
+
// assignment methods fail with Ruby 3.0
|
103
|
+
// TODO add checks like Tensor
|
101
104
|
.define_method(
|
102
|
-
"
|
105
|
+
"_set_grad",
|
103
106
|
[](Parameter& self, torch::Tensor& grad) {
|
104
107
|
self.mutable_grad() = grad;
|
105
108
|
})
|
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
472
472
|
auto param_idx = find_param(signature, key);
|
473
473
|
if (param_idx < 0) {
|
474
474
|
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
-
signature.name.c_str(),
|
475
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
476
476
|
}
|
477
477
|
|
478
478
|
if (param_idx < num_pos_args) {
|
479
479
|
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
-
signature.name.c_str(),
|
480
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
481
481
|
}
|
482
482
|
}
|
483
483
|
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -235,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
235
235
|
|
236
236
|
auto it = dtype_map.find(args[i]);
|
237
237
|
if (it == dtype_map.end()) {
|
238
|
-
rb_raise(rb_eArgError, "invalid dtype: %s",
|
238
|
+
rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
|
239
239
|
}
|
240
240
|
return it->second;
|
241
241
|
}
|
@@ -293,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
293
293
|
|
294
294
|
auto it = layout_map.find(args[i]);
|
295
295
|
if (it == layout_map.end()) {
|
296
|
-
rb_raise(rb_eArgError, "invalid layout: %s",
|
296
|
+
rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
|
297
297
|
}
|
298
298
|
return it->second;
|
299
299
|
}
|
@@ -325,15 +325,15 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
|
325
325
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
326
326
|
}
|
327
327
|
|
328
|
+
// string_view does not own data
|
328
329
|
inline c10::string_view RubyArgs::stringView(int i) {
|
329
|
-
|
330
|
-
return c10::string_view(str.data(), str.size());
|
330
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
331
331
|
}
|
332
332
|
|
333
|
+
// string_view does not own data
|
333
334
|
inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
|
334
335
|
if (NIL_P(args[i])) return c10::nullopt;
|
335
|
-
|
336
|
-
return c10::string_view(str.data(), str.size());
|
336
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
337
337
|
}
|
338
338
|
|
339
339
|
inline int64_t RubyArgs::toInt64(int i) {
|
data/ext/torch/templates.h
CHANGED
@@ -41,6 +41,40 @@ using torch::nn::init::NonlinearityType;
|
|
41
41
|
#define RETURN_NIL \
|
42
42
|
return Qnil;
|
43
43
|
|
44
|
+
namespace Rice::detail
|
45
|
+
{
|
46
|
+
template<typename T>
|
47
|
+
struct Type<c10::complex<T>>
|
48
|
+
{
|
49
|
+
static bool verify()
|
50
|
+
{
|
51
|
+
return true;
|
52
|
+
}
|
53
|
+
};
|
54
|
+
|
55
|
+
template<typename T>
|
56
|
+
class To_Ruby<c10::complex<T>>
|
57
|
+
{
|
58
|
+
public:
|
59
|
+
VALUE convert(c10::complex<T> const& x)
|
60
|
+
{
|
61
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
62
|
+
}
|
63
|
+
};
|
64
|
+
|
65
|
+
template<typename T>
|
66
|
+
class From_Ruby<c10::complex<T>>
|
67
|
+
{
|
68
|
+
public:
|
69
|
+
c10::complex<T> convert(VALUE x)
|
70
|
+
{
|
71
|
+
VALUE real = rb_funcall(x, rb_intern("real"), 0);
|
72
|
+
VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
|
73
|
+
return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
|
74
|
+
}
|
75
|
+
};
|
76
|
+
}
|
77
|
+
|
44
78
|
namespace Rice::detail
|
45
79
|
{
|
46
80
|
template<>
|
data/ext/torch/tensor.cpp
CHANGED
@@ -10,28 +10,6 @@
|
|
10
10
|
using namespace Rice;
|
11
11
|
using torch::indexing::TensorIndex;
|
12
12
|
|
13
|
-
namespace Rice::detail
|
14
|
-
{
|
15
|
-
template<typename T>
|
16
|
-
struct Type<c10::complex<T>>
|
17
|
-
{
|
18
|
-
static bool verify()
|
19
|
-
{
|
20
|
-
return true;
|
21
|
-
}
|
22
|
-
};
|
23
|
-
|
24
|
-
template<typename T>
|
25
|
-
class To_Ruby<c10::complex<T>>
|
26
|
-
{
|
27
|
-
public:
|
28
|
-
VALUE convert(c10::complex<T> const& x)
|
29
|
-
{
|
30
|
-
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
-
}
|
32
|
-
};
|
33
|
-
}
|
34
|
-
|
35
13
|
template<typename T>
|
36
14
|
Array flat_data(Tensor& tensor) {
|
37
15
|
Tensor view = tensor.reshape({tensor.numel()});
|
@@ -189,9 +167,31 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
189
167
|
auto grad = self.grad();
|
190
168
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
191
169
|
})
|
170
|
+
// can't use grad=
|
171
|
+
// assignment methods fail with Ruby 3.0
|
192
172
|
.define_method(
|
193
|
-
"
|
194
|
-
[](Tensor& self,
|
173
|
+
"_set_grad",
|
174
|
+
[](Tensor& self, Rice::Object value) {
|
175
|
+
if (value.is_nil()) {
|
176
|
+
self.mutable_grad().reset();
|
177
|
+
return;
|
178
|
+
}
|
179
|
+
|
180
|
+
const auto& grad = Rice::detail::From_Ruby<torch::Tensor>().convert(value.value());
|
181
|
+
|
182
|
+
// TODO support sparse grad
|
183
|
+
if (!grad.options().type_equal(self.options())) {
|
184
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different type");
|
185
|
+
}
|
186
|
+
|
187
|
+
if (self.is_cuda() && grad.get_device() != self.get_device()) {
|
188
|
+
rb_raise(rb_eArgError, "assigned grad has data located on a different device");
|
189
|
+
}
|
190
|
+
|
191
|
+
if (!self.sizes().equals(grad.sizes())) {
|
192
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different size");
|
193
|
+
}
|
194
|
+
|
195
195
|
self.mutable_grad() = grad;
|
196
196
|
})
|
197
197
|
.define_method(
|
@@ -281,7 +281,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
281
281
|
})
|
282
282
|
.define_method(
|
283
283
|
"_to",
|
284
|
-
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
284
|
+
[](Tensor& self, torch::Device& device, int dtype, bool non_blocking, bool copy) {
|
285
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
286
286
|
});
|
287
287
|
|
data/ext/torch/torch.cpp
CHANGED
@@ -6,6 +6,23 @@
|
|
6
6
|
#include "templates.h"
|
7
7
|
#include "utils.h"
|
8
8
|
|
9
|
+
template<typename T>
|
10
|
+
torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
11
|
+
std::vector<T> vec;
|
12
|
+
for (long i = 0; i < a.size(); i++) {
|
13
|
+
vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
|
14
|
+
}
|
15
|
+
|
16
|
+
// hack for requires_grad error
|
17
|
+
auto requires_grad = options.requires_grad();
|
18
|
+
torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
19
|
+
if (requires_grad) {
|
20
|
+
t.set_requires_grad(true);
|
21
|
+
}
|
22
|
+
|
23
|
+
return t.reshape(size);
|
24
|
+
}
|
25
|
+
|
9
26
|
void init_torch(Rice::Module& m) {
|
10
27
|
m.add_handler<torch::Error>(handle_error);
|
11
28
|
add_torch_functions(m);
|
@@ -61,35 +78,28 @@ void init_torch(Rice::Module& m) {
|
|
61
78
|
"_tensor",
|
62
79
|
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
80
|
auto dtype = options.dtype();
|
64
|
-
torch::
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
81
|
+
if (dtype == torch::kByte) {
|
82
|
+
return make_tensor<uint8_t>(a, size, options);
|
83
|
+
} else if (dtype == torch::kChar) {
|
84
|
+
return make_tensor<int8_t>(a, size, options);
|
85
|
+
} else if (dtype == torch::kShort) {
|
86
|
+
return make_tensor<int16_t>(a, size, options);
|
87
|
+
} else if (dtype == torch::kInt) {
|
88
|
+
return make_tensor<int32_t>(a, size, options);
|
89
|
+
} else if (dtype == torch::kLong) {
|
90
|
+
return make_tensor<int64_t>(a, size, options);
|
91
|
+
} else if (dtype == torch::kFloat) {
|
92
|
+
return make_tensor<float>(a, size, options);
|
93
|
+
} else if (dtype == torch::kDouble) {
|
94
|
+
return make_tensor<double>(a, size, options);
|
95
|
+
} else if (dtype == torch::kBool) {
|
96
|
+
return make_tensor<uint8_t>(a, size, options);
|
97
|
+
} else if (dtype == torch::kComplexFloat) {
|
98
|
+
return make_tensor<c10::complex<float>>(a, size, options);
|
99
|
+
} else if (dtype == torch::kComplexDouble) {
|
100
|
+
return make_tensor<c10::complex<double>>(a, size, options);
|
80
101
|
} else {
|
81
|
-
std::
|
82
|
-
for (long i = 0; i < a.size(); i++) {
|
83
|
-
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
84
|
-
}
|
85
|
-
// hack for requires_grad error
|
86
|
-
if (options.requires_grad()) {
|
87
|
-
t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
88
|
-
t.set_requires_grad(true);
|
89
|
-
} else {
|
90
|
-
t = torch::tensor(vec, options);
|
91
|
-
}
|
102
|
+
throw std::runtime_error("Unsupported type");
|
92
103
|
}
|
93
|
-
return t.reshape(size);
|
94
104
|
});
|
95
105
|
}
|
data/ext/torch/utils.h
CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
|
|
16
16
|
return Rice::Symbol(str);
|
17
17
|
}
|
18
18
|
|
19
|
-
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
20
|
-
Check_Type(obj, T_SYMBOL);
|
21
|
-
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
22
|
-
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
23
|
-
}
|
24
|
-
|
25
19
|
inline std::string THPUtils_unpackString(VALUE obj) {
|
26
20
|
Check_Type(obj, T_STRING);
|
27
21
|
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
data/lib/torch/inspector.rb
CHANGED
@@ -247,7 +247,7 @@ module Torch
|
|
247
247
|
# length includes spaces and comma between elements
|
248
248
|
element_length = formatter.width + 2
|
249
249
|
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
250
|
-
|
250
|
+
_char_per_line = element_length * elements_per_line
|
251
251
|
|
252
252
|
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
253
253
|
data = (
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -571,7 +571,7 @@ module Torch
|
|
571
571
|
end
|
572
572
|
|
573
573
|
def _interp_output_size(closed_over_args)
|
574
|
-
input, size, scale_factor,
|
574
|
+
input, size, scale_factor, _recompute_scale_factor = closed_over_args
|
575
575
|
dim = input.dim - 2
|
576
576
|
if size.nil? && scale_factor.nil?
|
577
577
|
raise ArgumentError, "either size or scale_factor should be defined"
|
data/lib/torch/nn/parameter.rb
CHANGED
@@ -0,0 +1,48 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class ParameterList < Module
|
4
|
+
include Enumerable
|
5
|
+
|
6
|
+
def initialize(parameters)
|
7
|
+
super()
|
8
|
+
@initialized = true
|
9
|
+
unless parameters.nil?
|
10
|
+
concat(parameters)
|
11
|
+
end
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@parameters.length
|
16
|
+
end
|
17
|
+
alias_method :count, :length
|
18
|
+
alias_method :size, :length
|
19
|
+
|
20
|
+
def concat(parameters)
|
21
|
+
unless parameters.is_a?(Enumerable)
|
22
|
+
raise TypeError, "ParameterList#concat should be called with an enumerable, but got #{parameters.class.name}"
|
23
|
+
end
|
24
|
+
offset = length
|
25
|
+
parameters.each_with_index do |param, i|
|
26
|
+
register_parameter((offset + i).to_s, param)
|
27
|
+
end
|
28
|
+
self
|
29
|
+
end
|
30
|
+
|
31
|
+
def each(&block)
|
32
|
+
if block_given?
|
33
|
+
@parameters.values.each(&block)
|
34
|
+
else
|
35
|
+
to_enum(:each)
|
36
|
+
end
|
37
|
+
end
|
38
|
+
|
39
|
+
def [](idx)
|
40
|
+
if idx.is_a?(Range)
|
41
|
+
self.class.new(@parameters.values[idx])
|
42
|
+
else
|
43
|
+
@parameters[idx.to_s]
|
44
|
+
end
|
45
|
+
end
|
46
|
+
end
|
47
|
+
end
|
48
|
+
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,9 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# fix for issue w/ assignment methods
|
12
|
+
alias_method :grad=, :_set_grad
|
13
|
+
|
11
14
|
# use alias_method for performance
|
12
15
|
alias_method :+, :add
|
13
16
|
alias_method :-, :sub
|
@@ -106,6 +109,7 @@ module Torch
|
|
106
109
|
size(0)
|
107
110
|
end
|
108
111
|
|
112
|
+
remove_method :item
|
109
113
|
def item
|
110
114
|
if numel != 1
|
111
115
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
@@ -133,18 +137,10 @@ module Torch
|
|
133
137
|
cls.from_string(_data_str).reshape(*shape)
|
134
138
|
end
|
135
139
|
|
136
|
-
def new_ones(*size, **options)
|
137
|
-
Torch.ones_like(Torch.empty(*size), **options)
|
138
|
-
end
|
139
|
-
|
140
140
|
def requires_grad=(requires_grad)
|
141
141
|
_requires_grad!(requires_grad)
|
142
142
|
end
|
143
143
|
|
144
|
-
def requires_grad!(requires_grad = true)
|
145
|
-
_requires_grad!(requires_grad)
|
146
|
-
end
|
147
|
-
|
148
144
|
def type(dtype)
|
149
145
|
if dtype.is_a?(Class)
|
150
146
|
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
@@ -29,7 +29,7 @@ module Torch
|
|
29
29
|
|
30
30
|
# try to keep the random number generator in sync with Python
|
31
31
|
# this makes it easy to compare results
|
32
|
-
|
32
|
+
_base_seed = Torch.empty([], dtype: :int64).random!.item
|
33
33
|
|
34
34
|
indexes =
|
35
35
|
if @shuffle
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -40,6 +40,7 @@ require "torch/nn/utils"
|
|
40
40
|
# nn containers
|
41
41
|
require "torch/nn/module"
|
42
42
|
require "torch/nn/module_list"
|
43
|
+
require "torch/nn/parameter_list"
|
43
44
|
require "torch/nn/sequential"
|
44
45
|
|
45
46
|
# nn convolution layers
|
@@ -267,7 +268,7 @@ module Torch
|
|
267
268
|
args.first.send(dtype).to(device)
|
268
269
|
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
269
270
|
bytes = args.first.bytes
|
270
|
-
Torch.
|
271
|
+
Torch._from_blob_ref(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
271
272
|
elsif args.size == 1 && args.first.is_a?(Array)
|
272
273
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
273
274
|
elsif args.size == 0
|
@@ -320,12 +321,17 @@ module Torch
|
|
320
321
|
raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
|
321
322
|
options = tensor_options(device: "cpu", dtype: dtype[0])
|
322
323
|
# TODO pass pointer to array instead of creating string
|
323
|
-
|
324
|
-
|
324
|
+
_from_blob_ref(ndarray.to_string, ndarray.shape, options)
|
325
|
+
end
|
326
|
+
|
327
|
+
# private
|
328
|
+
# TODO use keepAlive in Rice (currently segfaults)
|
329
|
+
def _from_blob_ref(data, size, options)
|
330
|
+
tensor = _from_blob(data, size, options)
|
325
331
|
# from_blob does not own the data, so we need to keep
|
326
332
|
# a reference to it for duration of tensor
|
327
333
|
# can remove when passing pointer directly
|
328
|
-
tensor.instance_variable_set("@
|
334
|
+
tensor.instance_variable_set("@_numo_data", data)
|
329
335
|
tensor
|
330
336
|
end
|
331
337
|
|
@@ -377,8 +383,6 @@ module Torch
|
|
377
383
|
to_ruby(_load(File.binread(f)))
|
378
384
|
end
|
379
385
|
|
380
|
-
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
381
|
-
|
382
386
|
def tensor(data, **options)
|
383
387
|
if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
|
384
388
|
numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
|
@@ -408,42 +412,13 @@ module Torch
|
|
408
412
|
end
|
409
413
|
end
|
410
414
|
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
# --- begin like ---
|
415
|
-
|
416
|
-
def ones_like(input, **options)
|
417
|
-
ones(input.size, **like_options(input, options))
|
418
|
-
end
|
419
|
-
|
420
|
-
def empty_like(input, **options)
|
421
|
-
empty(input.size, **like_options(input, options))
|
422
|
-
end
|
415
|
+
# TODO check each dimensions for consistency in future
|
416
|
+
raise Error, "Inconsistent dimensions" if data.size != size.inject(1, :*)
|
423
417
|
|
424
|
-
|
425
|
-
|
426
|
-
end
|
418
|
+
# TOOD move to C++
|
419
|
+
data = data.map { |v| v ? 1 : 0 } if options[:dtype] == :bool
|
427
420
|
|
428
|
-
|
429
|
-
rand(input.size, **like_options(input, options))
|
430
|
-
end
|
431
|
-
|
432
|
-
def randint_like(input, low, high = nil, **options)
|
433
|
-
# ruby doesn't support input, low = 0, high, ...
|
434
|
-
if high.nil?
|
435
|
-
high = low
|
436
|
-
low = 0
|
437
|
-
end
|
438
|
-
randint(low, high, input.size, **like_options(input, options))
|
439
|
-
end
|
440
|
-
|
441
|
-
def randn_like(input, **options)
|
442
|
-
randn(input.size, **like_options(input, options))
|
443
|
-
end
|
444
|
-
|
445
|
-
def zeros_like(input, **options)
|
446
|
-
zeros(input.size, **like_options(input, options))
|
421
|
+
_tensor(data, size, tensor_options(**options))
|
447
422
|
end
|
448
423
|
|
449
424
|
# center option
|
@@ -572,13 +547,5 @@ module Torch
|
|
572
547
|
end
|
573
548
|
options
|
574
549
|
end
|
575
|
-
|
576
|
-
def like_options(input, options)
|
577
|
-
options = options.dup
|
578
|
-
options[:dtype] ||= input.dtype
|
579
|
-
options[:layout] ||= input.layout
|
580
|
-
options[:device] ||= input.device
|
581
|
-
options
|
582
|
-
end
|
583
550
|
end
|
584
551
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.10.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2022-03-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -52,6 +52,7 @@ files:
|
|
52
52
|
- ext/torch/random.cpp
|
53
53
|
- ext/torch/ruby_arg_parser.cpp
|
54
54
|
- ext/torch/ruby_arg_parser.h
|
55
|
+
- ext/torch/sparse_functions.h
|
55
56
|
- ext/torch/special.cpp
|
56
57
|
- ext/torch/special_functions.h
|
57
58
|
- ext/torch/templates.h
|
@@ -149,6 +150,7 @@ files:
|
|
149
150
|
- lib/torch/nn/nll_loss.rb
|
150
151
|
- lib/torch/nn/pairwise_distance.rb
|
151
152
|
- lib/torch/nn/parameter.rb
|
153
|
+
- lib/torch/nn/parameter_list.rb
|
152
154
|
- lib/torch/nn/poisson_nll_loss.rb
|
153
155
|
- lib/torch/nn/prelu.rb
|
154
156
|
- lib/torch/nn/reflection_pad1d.rb
|
@@ -227,7 +229,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
227
229
|
- !ruby/object:Gem::Version
|
228
230
|
version: '0'
|
229
231
|
requirements: []
|
230
|
-
rubygems_version: 3.
|
232
|
+
rubygems_version: 3.3.7
|
231
233
|
signing_key:
|
232
234
|
specification_version: 4
|
233
235
|
summary: Deep learning for Ruby, powered by LibTorch
|