torch-rb 0.9.0 → 0.9.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -3
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +6 -6
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/tensor.rb +1 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +0 -45
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f5224c74f6e74ed04396dfa0414400af5cb20bc5e654320421116723ffcb8e83
|
4
|
+
data.tar.gz: 6a2881ddacb7610a231ebd5a1c24d0f71a2662f16f360102141dfd0893c13346
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 045432235e1c691ce85fb937a0562d93b1d9bc312fc648d40dafcc24857eeec4f84e6ceab397793171b0046ccfd785a6caeba37902925dcff1f73c760dd57cec
|
7
|
+
data.tar.gz: 2520fa17dcd13be52aaf1256f431d2951f8113f862189b8691f877dcb48ac9f71ac70719600e27d8c8972494c0b7f11c0983c3330fed7ca6aeed552e8500ec22
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -21,10 +21,10 @@ brew install libtorch
|
|
21
21
|
Add this line to your application’s Gemfile:
|
22
22
|
|
23
23
|
```ruby
|
24
|
-
gem
|
24
|
+
gem "torch-rb"
|
25
25
|
```
|
26
26
|
|
27
|
-
It can take
|
27
|
+
It can take 5-10 minutes to compile the extension.
|
28
28
|
|
29
29
|
## Getting Started
|
30
30
|
|
@@ -79,7 +79,7 @@ b = Torch.zeros(2, 3)
|
|
79
79
|
|
80
80
|
Each tensor has four properties
|
81
81
|
|
82
|
-
- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`,
|
82
|
+
- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
|
83
83
|
- `layout` - `:strided` (dense) or `:sparse`
|
84
84
|
- `device` - the compute device, like CPU or GPU
|
85
85
|
- `requires_grad` - whether or not to record gradients
|
data/ext/torch/backends.cpp
CHANGED
@@ -7,11 +7,11 @@
|
|
7
7
|
void init_backends(Rice::Module& m) {
|
8
8
|
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
9
|
|
10
|
-
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
11
|
.add_handler<torch::Error>(handle_error)
|
12
12
|
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
13
|
|
14
|
-
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
15
|
.add_handler<torch::Error>(handle_error)
|
16
16
|
.define_singleton_function("available?", &torch::hasMKL);
|
17
17
|
}
|
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
472
472
|
auto param_idx = find_param(signature, key);
|
473
473
|
if (param_idx < 0) {
|
474
474
|
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
-
signature.name.c_str(),
|
475
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
476
476
|
}
|
477
477
|
|
478
478
|
if (param_idx < num_pos_args) {
|
479
479
|
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
-
signature.name.c_str(),
|
480
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
481
481
|
}
|
482
482
|
}
|
483
483
|
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -235,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
235
235
|
|
236
236
|
auto it = dtype_map.find(args[i]);
|
237
237
|
if (it == dtype_map.end()) {
|
238
|
-
rb_raise(rb_eArgError, "invalid dtype: %s",
|
238
|
+
rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
|
239
239
|
}
|
240
240
|
return it->second;
|
241
241
|
}
|
@@ -293,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
293
293
|
|
294
294
|
auto it = layout_map.find(args[i]);
|
295
295
|
if (it == layout_map.end()) {
|
296
|
-
rb_raise(rb_eArgError, "invalid layout: %s",
|
296
|
+
rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
|
297
297
|
}
|
298
298
|
return it->second;
|
299
299
|
}
|
@@ -325,15 +325,15 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
|
325
325
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
326
326
|
}
|
327
327
|
|
328
|
+
// string_view does not own data
|
328
329
|
inline c10::string_view RubyArgs::stringView(int i) {
|
329
|
-
|
330
|
-
return c10::string_view(str.data(), str.size());
|
330
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
331
331
|
}
|
332
332
|
|
333
|
+
// string_view does not own data
|
333
334
|
inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
|
334
335
|
if (NIL_P(args[i])) return c10::nullopt;
|
335
|
-
|
336
|
-
return c10::string_view(str.data(), str.size());
|
336
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
337
337
|
}
|
338
338
|
|
339
339
|
inline int64_t RubyArgs::toInt64(int i) {
|
data/ext/torch/utils.h
CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
|
|
16
16
|
return Rice::Symbol(str);
|
17
17
|
}
|
18
18
|
|
19
|
-
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
20
|
-
Check_Type(obj, T_SYMBOL);
|
21
|
-
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
22
|
-
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
23
|
-
}
|
24
|
-
|
25
19
|
inline std::string THPUtils_unpackString(VALUE obj) {
|
26
20
|
Check_Type(obj, T_STRING);
|
27
21
|
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
data/lib/torch/inspector.rb
CHANGED
@@ -247,7 +247,7 @@ module Torch
|
|
247
247
|
# length includes spaces and comma between elements
|
248
248
|
element_length = formatter.width + 2
|
249
249
|
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
250
|
-
|
250
|
+
_char_per_line = element_length * elements_per_line
|
251
251
|
|
252
252
|
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
253
253
|
data = (
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -571,7 +571,7 @@ module Torch
|
|
571
571
|
end
|
572
572
|
|
573
573
|
def _interp_output_size(closed_over_args)
|
574
|
-
input, size, scale_factor,
|
574
|
+
input, size, scale_factor, _recompute_scale_factor = closed_over_args
|
575
575
|
dim = input.dim - 2
|
576
576
|
if size.nil? && scale_factor.nil?
|
577
577
|
raise ArgumentError, "either size or scale_factor should be defined"
|
data/lib/torch/tensor.rb
CHANGED
@@ -106,6 +106,7 @@ module Torch
|
|
106
106
|
size(0)
|
107
107
|
end
|
108
108
|
|
109
|
+
remove_method :item
|
109
110
|
def item
|
110
111
|
if numel != 1
|
111
112
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
@@ -133,18 +134,10 @@ module Torch
|
|
133
134
|
cls.from_string(_data_str).reshape(*shape)
|
134
135
|
end
|
135
136
|
|
136
|
-
def new_ones(*size, **options)
|
137
|
-
Torch.ones_like(Torch.empty(*size), **options)
|
138
|
-
end
|
139
|
-
|
140
137
|
def requires_grad=(requires_grad)
|
141
138
|
_requires_grad!(requires_grad)
|
142
139
|
end
|
143
140
|
|
144
|
-
def requires_grad!(requires_grad = true)
|
145
|
-
_requires_grad!(requires_grad)
|
146
|
-
end
|
147
|
-
|
148
141
|
def type(dtype)
|
149
142
|
if dtype.is_a?(Class)
|
150
143
|
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -29,7 +29,7 @@ module Torch
|
|
29
29
|
|
30
30
|
# try to keep the random number generator in sync with Python
|
31
31
|
# this makes it easy to compare results
|
32
|
-
|
32
|
+
_base_seed = Torch.empty([], dtype: :int64).random!.item
|
33
33
|
|
34
34
|
indexes =
|
35
35
|
if @shuffle
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -377,8 +377,6 @@ module Torch
|
|
377
377
|
to_ruby(_load(File.binread(f)))
|
378
378
|
end
|
379
379
|
|
380
|
-
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
381
|
-
|
382
380
|
def tensor(data, **options)
|
383
381
|
if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
|
384
382
|
numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
|
@@ -411,41 +409,6 @@ module Torch
|
|
411
409
|
_tensor(data, size, tensor_options(**options))
|
412
410
|
end
|
413
411
|
|
414
|
-
# --- begin like ---
|
415
|
-
|
416
|
-
def ones_like(input, **options)
|
417
|
-
ones(input.size, **like_options(input, options))
|
418
|
-
end
|
419
|
-
|
420
|
-
def empty_like(input, **options)
|
421
|
-
empty(input.size, **like_options(input, options))
|
422
|
-
end
|
423
|
-
|
424
|
-
def full_like(input, fill_value, **options)
|
425
|
-
full(input.size, fill_value, **like_options(input, options))
|
426
|
-
end
|
427
|
-
|
428
|
-
def rand_like(input, **options)
|
429
|
-
rand(input.size, **like_options(input, options))
|
430
|
-
end
|
431
|
-
|
432
|
-
def randint_like(input, low, high = nil, **options)
|
433
|
-
# ruby doesn't support input, low = 0, high, ...
|
434
|
-
if high.nil?
|
435
|
-
high = low
|
436
|
-
low = 0
|
437
|
-
end
|
438
|
-
randint(low, high, input.size, **like_options(input, options))
|
439
|
-
end
|
440
|
-
|
441
|
-
def randn_like(input, **options)
|
442
|
-
randn(input.size, **like_options(input, options))
|
443
|
-
end
|
444
|
-
|
445
|
-
def zeros_like(input, **options)
|
446
|
-
zeros(input.size, **like_options(input, options))
|
447
|
-
end
|
448
|
-
|
449
412
|
# center option
|
450
413
|
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
|
451
414
|
if center
|
@@ -572,13 +535,5 @@ module Torch
|
|
572
535
|
end
|
573
536
|
options
|
574
537
|
end
|
575
|
-
|
576
|
-
def like_options(input, options)
|
577
|
-
options = options.dup
|
578
|
-
options[:dtype] ||= input.dtype
|
579
|
-
options[:layout] ||= input.layout
|
580
|
-
options[:device] ||= input.device
|
581
|
-
options
|
582
|
-
end
|
583
538
|
end
|
584
539
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.9.
|
4
|
+
version: 0.9.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2022-02-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -227,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
227
227
|
- !ruby/object:Gem::Version
|
228
228
|
version: '0'
|
229
229
|
requirements: []
|
230
|
-
rubygems_version: 3.
|
230
|
+
rubygems_version: 3.3.3
|
231
231
|
signing_key:
|
232
232
|
specification_version: 4
|
233
233
|
summary: Deep learning for Ruby, powered by LibTorch
|