torch-rb 0.9.0 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/README.md +6 -4
- data/codegen/function.rb +2 -2
- data/codegen/generate_functions.rb +5 -1
- data/codegen/native_functions.yaml +951 -362
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +6 -6
- data/ext/torch/sparse_functions.h +6 -0
- data/ext/torch/templates.h +34 -0
- data/ext/torch/tensor.cpp +25 -25
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/nn/parameter.rb +3 -0
- data/lib/torch/nn/parameter_list.rb +48 -0
- data/lib/torch/tensor.rb +4 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +15 -48
- metadata +5 -3
data/ext/torch/backends.cpp
CHANGED
@@ -7,11 +7,11 @@
|
|
7
7
|
void init_backends(Rice::Module& m) {
|
8
8
|
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
9
|
|
10
|
-
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
11
|
.add_handler<torch::Error>(handle_error)
|
12
12
|
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
13
|
|
14
|
-
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
15
|
.add_handler<torch::Error>(handle_error)
|
16
16
|
.define_singleton_function("available?", &torch::hasMKL);
|
17
17
|
}
|
data/ext/torch/nn.cpp
CHANGED
@@ -98,8 +98,11 @@ void init_nn(Rice::Module& m) {
|
|
98
98
|
auto grad = self.grad();
|
99
99
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
|
+
// can't use grad=
|
102
|
+
// assignment methods fail with Ruby 3.0
|
103
|
+
// TODO add checks like Tensor
|
101
104
|
.define_method(
|
102
|
-
"
|
105
|
+
"_set_grad",
|
103
106
|
[](Parameter& self, torch::Tensor& grad) {
|
104
107
|
self.mutable_grad() = grad;
|
105
108
|
})
|
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
472
472
|
auto param_idx = find_param(signature, key);
|
473
473
|
if (param_idx < 0) {
|
474
474
|
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
-
signature.name.c_str(),
|
475
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
476
476
|
}
|
477
477
|
|
478
478
|
if (param_idx < num_pos_args) {
|
479
479
|
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
-
signature.name.c_str(),
|
480
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
481
481
|
}
|
482
482
|
}
|
483
483
|
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -235,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
235
235
|
|
236
236
|
auto it = dtype_map.find(args[i]);
|
237
237
|
if (it == dtype_map.end()) {
|
238
|
-
rb_raise(rb_eArgError, "invalid dtype: %s",
|
238
|
+
rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
|
239
239
|
}
|
240
240
|
return it->second;
|
241
241
|
}
|
@@ -293,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
293
293
|
|
294
294
|
auto it = layout_map.find(args[i]);
|
295
295
|
if (it == layout_map.end()) {
|
296
|
-
rb_raise(rb_eArgError, "invalid layout: %s",
|
296
|
+
rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
|
297
297
|
}
|
298
298
|
return it->second;
|
299
299
|
}
|
@@ -325,15 +325,15 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
|
325
325
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
326
326
|
}
|
327
327
|
|
328
|
+
// string_view does not own data
|
328
329
|
inline c10::string_view RubyArgs::stringView(int i) {
|
329
|
-
|
330
|
-
return c10::string_view(str.data(), str.size());
|
330
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
331
331
|
}
|
332
332
|
|
333
|
+
// string_view does not own data
|
333
334
|
inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
|
334
335
|
if (NIL_P(args[i])) return c10::nullopt;
|
335
|
-
|
336
|
-
return c10::string_view(str.data(), str.size());
|
336
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
337
337
|
}
|
338
338
|
|
339
339
|
inline int64_t RubyArgs::toInt64(int i) {
|
data/ext/torch/templates.h
CHANGED
@@ -41,6 +41,40 @@ using torch::nn::init::NonlinearityType;
|
|
41
41
|
#define RETURN_NIL \
|
42
42
|
return Qnil;
|
43
43
|
|
44
|
+
namespace Rice::detail
|
45
|
+
{
|
46
|
+
template<typename T>
|
47
|
+
struct Type<c10::complex<T>>
|
48
|
+
{
|
49
|
+
static bool verify()
|
50
|
+
{
|
51
|
+
return true;
|
52
|
+
}
|
53
|
+
};
|
54
|
+
|
55
|
+
template<typename T>
|
56
|
+
class To_Ruby<c10::complex<T>>
|
57
|
+
{
|
58
|
+
public:
|
59
|
+
VALUE convert(c10::complex<T> const& x)
|
60
|
+
{
|
61
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
62
|
+
}
|
63
|
+
};
|
64
|
+
|
65
|
+
template<typename T>
|
66
|
+
class From_Ruby<c10::complex<T>>
|
67
|
+
{
|
68
|
+
public:
|
69
|
+
c10::complex<T> convert(VALUE x)
|
70
|
+
{
|
71
|
+
VALUE real = rb_funcall(x, rb_intern("real"), 0);
|
72
|
+
VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
|
73
|
+
return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
|
74
|
+
}
|
75
|
+
};
|
76
|
+
}
|
77
|
+
|
44
78
|
namespace Rice::detail
|
45
79
|
{
|
46
80
|
template<>
|
data/ext/torch/tensor.cpp
CHANGED
@@ -10,28 +10,6 @@
|
|
10
10
|
using namespace Rice;
|
11
11
|
using torch::indexing::TensorIndex;
|
12
12
|
|
13
|
-
namespace Rice::detail
|
14
|
-
{
|
15
|
-
template<typename T>
|
16
|
-
struct Type<c10::complex<T>>
|
17
|
-
{
|
18
|
-
static bool verify()
|
19
|
-
{
|
20
|
-
return true;
|
21
|
-
}
|
22
|
-
};
|
23
|
-
|
24
|
-
template<typename T>
|
25
|
-
class To_Ruby<c10::complex<T>>
|
26
|
-
{
|
27
|
-
public:
|
28
|
-
VALUE convert(c10::complex<T> const& x)
|
29
|
-
{
|
30
|
-
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
-
}
|
32
|
-
};
|
33
|
-
}
|
34
|
-
|
35
13
|
template<typename T>
|
36
14
|
Array flat_data(Tensor& tensor) {
|
37
15
|
Tensor view = tensor.reshape({tensor.numel()});
|
@@ -189,9 +167,31 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
189
167
|
auto grad = self.grad();
|
190
168
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
191
169
|
})
|
170
|
+
// can't use grad=
|
171
|
+
// assignment methods fail with Ruby 3.0
|
192
172
|
.define_method(
|
193
|
-
"
|
194
|
-
[](Tensor& self,
|
173
|
+
"_set_grad",
|
174
|
+
[](Tensor& self, Rice::Object value) {
|
175
|
+
if (value.is_nil()) {
|
176
|
+
self.mutable_grad().reset();
|
177
|
+
return;
|
178
|
+
}
|
179
|
+
|
180
|
+
const auto& grad = Rice::detail::From_Ruby<torch::Tensor>().convert(value.value());
|
181
|
+
|
182
|
+
// TODO support sparse grad
|
183
|
+
if (!grad.options().type_equal(self.options())) {
|
184
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different type");
|
185
|
+
}
|
186
|
+
|
187
|
+
if (self.is_cuda() && grad.get_device() != self.get_device()) {
|
188
|
+
rb_raise(rb_eArgError, "assigned grad has data located on a different device");
|
189
|
+
}
|
190
|
+
|
191
|
+
if (!self.sizes().equals(grad.sizes())) {
|
192
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different size");
|
193
|
+
}
|
194
|
+
|
195
195
|
self.mutable_grad() = grad;
|
196
196
|
})
|
197
197
|
.define_method(
|
@@ -281,7 +281,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
281
281
|
})
|
282
282
|
.define_method(
|
283
283
|
"_to",
|
284
|
-
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
284
|
+
[](Tensor& self, torch::Device& device, int dtype, bool non_blocking, bool copy) {
|
285
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
286
286
|
});
|
287
287
|
|
data/ext/torch/torch.cpp
CHANGED
@@ -6,6 +6,23 @@
|
|
6
6
|
#include "templates.h"
|
7
7
|
#include "utils.h"
|
8
8
|
|
9
|
+
template<typename T>
|
10
|
+
torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
11
|
+
std::vector<T> vec;
|
12
|
+
for (long i = 0; i < a.size(); i++) {
|
13
|
+
vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
|
14
|
+
}
|
15
|
+
|
16
|
+
// hack for requires_grad error
|
17
|
+
auto requires_grad = options.requires_grad();
|
18
|
+
torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
19
|
+
if (requires_grad) {
|
20
|
+
t.set_requires_grad(true);
|
21
|
+
}
|
22
|
+
|
23
|
+
return t.reshape(size);
|
24
|
+
}
|
25
|
+
|
9
26
|
void init_torch(Rice::Module& m) {
|
10
27
|
m.add_handler<torch::Error>(handle_error);
|
11
28
|
add_torch_functions(m);
|
@@ -61,35 +78,28 @@ void init_torch(Rice::Module& m) {
|
|
61
78
|
"_tensor",
|
62
79
|
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
80
|
auto dtype = options.dtype();
|
64
|
-
torch::
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
81
|
+
if (dtype == torch::kByte) {
|
82
|
+
return make_tensor<uint8_t>(a, size, options);
|
83
|
+
} else if (dtype == torch::kChar) {
|
84
|
+
return make_tensor<int8_t>(a, size, options);
|
85
|
+
} else if (dtype == torch::kShort) {
|
86
|
+
return make_tensor<int16_t>(a, size, options);
|
87
|
+
} else if (dtype == torch::kInt) {
|
88
|
+
return make_tensor<int32_t>(a, size, options);
|
89
|
+
} else if (dtype == torch::kLong) {
|
90
|
+
return make_tensor<int64_t>(a, size, options);
|
91
|
+
} else if (dtype == torch::kFloat) {
|
92
|
+
return make_tensor<float>(a, size, options);
|
93
|
+
} else if (dtype == torch::kDouble) {
|
94
|
+
return make_tensor<double>(a, size, options);
|
95
|
+
} else if (dtype == torch::kBool) {
|
96
|
+
return make_tensor<uint8_t>(a, size, options);
|
97
|
+
} else if (dtype == torch::kComplexFloat) {
|
98
|
+
return make_tensor<c10::complex<float>>(a, size, options);
|
99
|
+
} else if (dtype == torch::kComplexDouble) {
|
100
|
+
return make_tensor<c10::complex<double>>(a, size, options);
|
80
101
|
} else {
|
81
|
-
std::
|
82
|
-
for (long i = 0; i < a.size(); i++) {
|
83
|
-
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
84
|
-
}
|
85
|
-
// hack for requires_grad error
|
86
|
-
if (options.requires_grad()) {
|
87
|
-
t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
88
|
-
t.set_requires_grad(true);
|
89
|
-
} else {
|
90
|
-
t = torch::tensor(vec, options);
|
91
|
-
}
|
102
|
+
throw std::runtime_error("Unsupported type");
|
92
103
|
}
|
93
|
-
return t.reshape(size);
|
94
104
|
});
|
95
105
|
}
|
data/ext/torch/utils.h
CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
|
|
16
16
|
return Rice::Symbol(str);
|
17
17
|
}
|
18
18
|
|
19
|
-
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
20
|
-
Check_Type(obj, T_SYMBOL);
|
21
|
-
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
22
|
-
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
23
|
-
}
|
24
|
-
|
25
19
|
inline std::string THPUtils_unpackString(VALUE obj) {
|
26
20
|
Check_Type(obj, T_STRING);
|
27
21
|
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
data/lib/torch/inspector.rb
CHANGED
@@ -247,7 +247,7 @@ module Torch
|
|
247
247
|
# length includes spaces and comma between elements
|
248
248
|
element_length = formatter.width + 2
|
249
249
|
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
250
|
-
|
250
|
+
_char_per_line = element_length * elements_per_line
|
251
251
|
|
252
252
|
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
253
253
|
data = (
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -571,7 +571,7 @@ module Torch
|
|
571
571
|
end
|
572
572
|
|
573
573
|
def _interp_output_size(closed_over_args)
|
574
|
-
input, size, scale_factor,
|
574
|
+
input, size, scale_factor, _recompute_scale_factor = closed_over_args
|
575
575
|
dim = input.dim - 2
|
576
576
|
if size.nil? && scale_factor.nil?
|
577
577
|
raise ArgumentError, "either size or scale_factor should be defined"
|
data/lib/torch/nn/parameter.rb
CHANGED
@@ -0,0 +1,48 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class ParameterList < Module
|
4
|
+
include Enumerable
|
5
|
+
|
6
|
+
def initialize(parameters)
|
7
|
+
super()
|
8
|
+
@initialized = true
|
9
|
+
unless parameters.nil?
|
10
|
+
concat(parameters)
|
11
|
+
end
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@parameters.length
|
16
|
+
end
|
17
|
+
alias_method :count, :length
|
18
|
+
alias_method :size, :length
|
19
|
+
|
20
|
+
def concat(parameters)
|
21
|
+
unless parameters.is_a?(Enumerable)
|
22
|
+
raise TypeError, "ParameterList#concat should be called with an enumerable, but got #{parameters.class.name}"
|
23
|
+
end
|
24
|
+
offset = length
|
25
|
+
parameters.each_with_index do |param, i|
|
26
|
+
register_parameter((offset + i).to_s, param)
|
27
|
+
end
|
28
|
+
self
|
29
|
+
end
|
30
|
+
|
31
|
+
def each(&block)
|
32
|
+
if block_given?
|
33
|
+
@parameters.values.each(&block)
|
34
|
+
else
|
35
|
+
to_enum(:each)
|
36
|
+
end
|
37
|
+
end
|
38
|
+
|
39
|
+
def [](idx)
|
40
|
+
if idx.is_a?(Range)
|
41
|
+
self.class.new(@parameters.values[idx])
|
42
|
+
else
|
43
|
+
@parameters[idx.to_s]
|
44
|
+
end
|
45
|
+
end
|
46
|
+
end
|
47
|
+
end
|
48
|
+
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,9 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# fix for issue w/ assignment methods
|
12
|
+
alias_method :grad=, :_set_grad
|
13
|
+
|
11
14
|
# use alias_method for performance
|
12
15
|
alias_method :+, :add
|
13
16
|
alias_method :-, :sub
|
@@ -106,6 +109,7 @@ module Torch
|
|
106
109
|
size(0)
|
107
110
|
end
|
108
111
|
|
112
|
+
remove_method :item
|
109
113
|
def item
|
110
114
|
if numel != 1
|
111
115
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
@@ -133,18 +137,10 @@ module Torch
|
|
133
137
|
cls.from_string(_data_str).reshape(*shape)
|
134
138
|
end
|
135
139
|
|
136
|
-
def new_ones(*size, **options)
|
137
|
-
Torch.ones_like(Torch.empty(*size), **options)
|
138
|
-
end
|
139
|
-
|
140
140
|
def requires_grad=(requires_grad)
|
141
141
|
_requires_grad!(requires_grad)
|
142
142
|
end
|
143
143
|
|
144
|
-
def requires_grad!(requires_grad = true)
|
145
|
-
_requires_grad!(requires_grad)
|
146
|
-
end
|
147
|
-
|
148
144
|
def type(dtype)
|
149
145
|
if dtype.is_a?(Class)
|
150
146
|
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
@@ -29,7 +29,7 @@ module Torch
|
|
29
29
|
|
30
30
|
# try to keep the random number generator in sync with Python
|
31
31
|
# this makes it easy to compare results
|
32
|
-
|
32
|
+
_base_seed = Torch.empty([], dtype: :int64).random!.item
|
33
33
|
|
34
34
|
indexes =
|
35
35
|
if @shuffle
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -40,6 +40,7 @@ require "torch/nn/utils"
|
|
40
40
|
# nn containers
|
41
41
|
require "torch/nn/module"
|
42
42
|
require "torch/nn/module_list"
|
43
|
+
require "torch/nn/parameter_list"
|
43
44
|
require "torch/nn/sequential"
|
44
45
|
|
45
46
|
# nn convolution layers
|
@@ -267,7 +268,7 @@ module Torch
|
|
267
268
|
args.first.send(dtype).to(device)
|
268
269
|
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
269
270
|
bytes = args.first.bytes
|
270
|
-
Torch.
|
271
|
+
Torch._from_blob_ref(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
271
272
|
elsif args.size == 1 && args.first.is_a?(Array)
|
272
273
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
273
274
|
elsif args.size == 0
|
@@ -320,12 +321,17 @@ module Torch
|
|
320
321
|
raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
|
321
322
|
options = tensor_options(device: "cpu", dtype: dtype[0])
|
322
323
|
# TODO pass pointer to array instead of creating string
|
323
|
-
|
324
|
-
|
324
|
+
_from_blob_ref(ndarray.to_string, ndarray.shape, options)
|
325
|
+
end
|
326
|
+
|
327
|
+
# private
|
328
|
+
# TODO use keepAlive in Rice (currently segfaults)
|
329
|
+
def _from_blob_ref(data, size, options)
|
330
|
+
tensor = _from_blob(data, size, options)
|
325
331
|
# from_blob does not own the data, so we need to keep
|
326
332
|
# a reference to it for duration of tensor
|
327
333
|
# can remove when passing pointer directly
|
328
|
-
tensor.instance_variable_set("@
|
334
|
+
tensor.instance_variable_set("@_numo_data", data)
|
329
335
|
tensor
|
330
336
|
end
|
331
337
|
|
@@ -377,8 +383,6 @@ module Torch
|
|
377
383
|
to_ruby(_load(File.binread(f)))
|
378
384
|
end
|
379
385
|
|
380
|
-
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
381
|
-
|
382
386
|
def tensor(data, **options)
|
383
387
|
if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
|
384
388
|
numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
|
@@ -408,42 +412,13 @@ module Torch
|
|
408
412
|
end
|
409
413
|
end
|
410
414
|
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
# --- begin like ---
|
415
|
-
|
416
|
-
def ones_like(input, **options)
|
417
|
-
ones(input.size, **like_options(input, options))
|
418
|
-
end
|
419
|
-
|
420
|
-
def empty_like(input, **options)
|
421
|
-
empty(input.size, **like_options(input, options))
|
422
|
-
end
|
415
|
+
# TODO check each dimensions for consistency in future
|
416
|
+
raise Error, "Inconsistent dimensions" if data.size != size.inject(1, :*)
|
423
417
|
|
424
|
-
|
425
|
-
|
426
|
-
end
|
418
|
+
# TOOD move to C++
|
419
|
+
data = data.map { |v| v ? 1 : 0 } if options[:dtype] == :bool
|
427
420
|
|
428
|
-
|
429
|
-
rand(input.size, **like_options(input, options))
|
430
|
-
end
|
431
|
-
|
432
|
-
def randint_like(input, low, high = nil, **options)
|
433
|
-
# ruby doesn't support input, low = 0, high, ...
|
434
|
-
if high.nil?
|
435
|
-
high = low
|
436
|
-
low = 0
|
437
|
-
end
|
438
|
-
randint(low, high, input.size, **like_options(input, options))
|
439
|
-
end
|
440
|
-
|
441
|
-
def randn_like(input, **options)
|
442
|
-
randn(input.size, **like_options(input, options))
|
443
|
-
end
|
444
|
-
|
445
|
-
def zeros_like(input, **options)
|
446
|
-
zeros(input.size, **like_options(input, options))
|
421
|
+
_tensor(data, size, tensor_options(**options))
|
447
422
|
end
|
448
423
|
|
449
424
|
# center option
|
@@ -572,13 +547,5 @@ module Torch
|
|
572
547
|
end
|
573
548
|
options
|
574
549
|
end
|
575
|
-
|
576
|
-
def like_options(input, options)
|
577
|
-
options = options.dup
|
578
|
-
options[:dtype] ||= input.dtype
|
579
|
-
options[:layout] ||= input.layout
|
580
|
-
options[:device] ||= input.device
|
581
|
-
options
|
582
|
-
end
|
583
550
|
end
|
584
551
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.10.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2022-03-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -52,6 +52,7 @@ files:
|
|
52
52
|
- ext/torch/random.cpp
|
53
53
|
- ext/torch/ruby_arg_parser.cpp
|
54
54
|
- ext/torch/ruby_arg_parser.h
|
55
|
+
- ext/torch/sparse_functions.h
|
55
56
|
- ext/torch/special.cpp
|
56
57
|
- ext/torch/special_functions.h
|
57
58
|
- ext/torch/templates.h
|
@@ -149,6 +150,7 @@ files:
|
|
149
150
|
- lib/torch/nn/nll_loss.rb
|
150
151
|
- lib/torch/nn/pairwise_distance.rb
|
151
152
|
- lib/torch/nn/parameter.rb
|
153
|
+
- lib/torch/nn/parameter_list.rb
|
152
154
|
- lib/torch/nn/poisson_nll_loss.rb
|
153
155
|
- lib/torch/nn/prelu.rb
|
154
156
|
- lib/torch/nn/reflection_pad1d.rb
|
@@ -227,7 +229,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
227
229
|
- !ruby/object:Gem::Version
|
228
230
|
version: '0'
|
229
231
|
requirements: []
|
230
|
-
rubygems_version: 3.
|
232
|
+
rubygems_version: 3.3.7
|
231
233
|
signing_key:
|
232
234
|
specification_version: 4
|
233
235
|
summary: Deep learning for Ruby, powered by LibTorch
|