torch-rb 0.9.0 → 0.9.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -3
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +6 -6
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/tensor.rb +1 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +0 -45
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: f5224c74f6e74ed04396dfa0414400af5cb20bc5e654320421116723ffcb8e83
+  data.tar.gz: 6a2881ddacb7610a231ebd5a1c24d0f71a2662f16f360102141dfd0893c13346
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 045432235e1c691ce85fb937a0562d93b1d9bc312fc648d40dafcc24857eeec4f84e6ceab397793171b0046ccfd785a6caeba37902925dcff1f73c760dd57cec
+  data.tar.gz: 2520fa17dcd13be52aaf1256f431d2951f8113f862189b8691f877dcb48ac9f71ac70719600e27d8c8972494c0b7f11c0983c3330fed7ca6aeed552e8500ec22
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -21,10 +21,10 @@ brew install libtorch
 Add this line to your application’s Gemfile:
 
 ```ruby
-gem
+gem "torch-rb"
 ```
 
-It can take
+It can take 5-10 minutes to compile the extension.
 
 ## Getting Started
 
@@ -79,7 +79,7 @@ b = Torch.zeros(2, 3)
 
 Each tensor has four properties
 
-- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`,
+- `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
- `layout` - `:strided` (dense) or `:sparse`
- `device` - the compute device, like CPU or GPU
- `requires_grad` - whether or not to record gradients
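As a quick illustration of the four properties in the updated list above (a minimal sketch; the exact values depend on your input data and build):

```ruby
require "torch"

x = Torch.tensor([[1, 2, 3], [4, 5, 6]])

x.dtype         # typically :int64 for integer input
x.layout        # :strided
x.device        # the tensor's device (CPU by default)
x.requires_grad # false unless gradients are enabled
```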
data/ext/torch/backends.cpp
CHANGED
@@ -7,11 +7,11 @@
 void init_backends(Rice::Module& m) {
   auto rb_mBackends = Rice::define_module_under(m, "Backends");
 
-
+  Rice::define_module_under(rb_mBackends, "OpenMP")
     .add_handler<torch::Error>(handle_error)
     .define_singleton_function("available?", &torch::hasOpenMP);
 
-
+  Rice::define_module_under(rb_mBackends, "MKL")
     .add_handler<torch::Error>(handle_error)
     .define_singleton_function("available?", &torch::hasMKL);
 }
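These bindings are the Ruby-facing feature checks; assuming `m` here is the top-level `Torch` module (not shown in this hunk), usage looks roughly like:

```ruby
require "torch"

# Wraps torch::hasOpenMP / torch::hasMKL from backends.cpp above
Torch::Backends::OpenMP.available? # => true or false
Torch::Backends::MKL.available?    # => true or false
```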
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
     auto param_idx = find_param(signature, key);
     if (param_idx < 0) {
       rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
-        signature.name.c_str(),
+        signature.name.c_str(), rb_id2name(rb_to_id(key)));
     }
 
     if (param_idx < num_pos_args) {
       rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
-        signature.name.c_str(),
+        signature.name.c_str(), rb_id2name(rb_to_id(key)));
     }
   }
 
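These two calls build the ArgumentError text raised when a keyword argument does not match a function's signature. A hedged sketch of how that surfaces from Ruby (assuming `matmul` resolves to a single signature, so this branch of the parser is the one that reports the error):

```ruby
require "torch"

a = Torch.rand(2, 3)
b = Torch.rand(3, 2)

begin
  Torch.matmul(a, b, foo: 1)
rescue ArgumentError => e
  e.message # e.g. "matmul() got an unexpected keyword argument 'foo'"
end
```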
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -235,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
 
   auto it = dtype_map.find(args[i]);
   if (it == dtype_map.end()) {
-    rb_raise(rb_eArgError, "invalid dtype: %s",
+    rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
   }
   return it->second;
 }
@@ -293,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
 
   auto it = layout_map.find(args[i]);
   if (it == layout_map.end()) {
-    rb_raise(rb_eArgError, "invalid layout: %s",
+    rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
   }
   return it->second;
 }
@@ -325,15 +325,15 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
   return Rice::detail::From_Ruby<std::string>().convert(args[i]);
 }
 
+// string_view does not own data
 inline c10::string_view RubyArgs::stringView(int i) {
-
-  return c10::string_view(str.data(), str.size());
+  return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
 }
 
+// string_view does not own data
 inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
-
-  return c10::string_view(str.data(), str.size());
+  return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
 }
 
 inline int64_t RubyArgs::toInt64(int i) {
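From the Ruby side, the first two hunks are the messages you see when a symbol matches a signature but is not a recognized dtype or layout. A hedged example (assuming `:float128` is not in `dtype_map` and that signature matching accepts any symbol for the `dtype` parameter):

```ruby
require "torch"

begin
  Torch.ones(2, 3, dtype: :float128)
rescue ArgumentError => e
  e.message # e.g. "invalid dtype: float128"
end
```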
data/ext/torch/utils.h
CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
   return Rice::Symbol(str);
 }
 
-inline std::string THPUtils_unpackSymbol(VALUE obj) {
-  Check_Type(obj, T_SYMBOL);
-  obj = rb_funcall(obj, rb_intern("to_s"), 0);
-  return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
-}
-
 inline std::string THPUtils_unpackString(VALUE obj) {
   Check_Type(obj, T_STRING);
   return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
data/lib/torch/inspector.rb
CHANGED
@@ -247,7 +247,7 @@ module Torch
       # length includes spaces and comma between elements
       element_length = formatter.width + 2
       elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
-
+      _char_per_line = element_length * elements_per_line
 
       if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
         data = (
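This code drives how tensors are rendered by `inspect`: rows are wrapped to `PRINT_OPTS[:linewidth]`, and long dimensions are summarized once they exceed the edge-item settings. A minimal way to see it in action (output formatting will vary):

```ruby
require "torch"

# A wide tensor gets wrapped across lines; a very long one is summarized with "..."
puts Torch.rand(2, 8).inspect
```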
data/lib/torch/nn/functional.rb
CHANGED
@@ -571,7 +571,7 @@ module Torch
       end
 
       def _interp_output_size(closed_over_args)
-        input, size, scale_factor,
+        input, size, scale_factor, _recompute_scale_factor = closed_over_args
         dim = input.dim - 2
         if size.nil? && scale_factor.nil?
           raise ArgumentError, "either size or scale_factor should be defined"
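The corrected destructuring feeds the interpolation output-size logic. As a standalone sketch of that logic (illustrative only, not a call into the gem's private `_interp_output_size`), resolving an output size from either an explicit `size` or a `scale_factor` looks like:

```ruby
# Hypothetical helper mirroring the intent of _interp_output_size for the
# spatial dimensions of an input (dim = input.dim - 2).
def interp_output_size(spatial_shape, size: nil, scale_factor: nil)
  raise ArgumentError, "either size or scale_factor should be defined" if size.nil? && scale_factor.nil?
  return size unless size.nil?
  spatial_shape.map { |d| (d * scale_factor).floor }
end

interp_output_size([16, 16], scale_factor: 2.0) # => [32, 32]
```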
data/lib/torch/tensor.rb
CHANGED
@@ -106,6 +106,7 @@ module Torch
       size(0)
     end
 
+    remove_method :item
     def item
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
@@ -133,18 +134,10 @@ module Torch
       cls.from_string(_data_str).reshape(*shape)
     end
 
-    def new_ones(*size, **options)
-      Torch.ones_like(Torch.empty(*size), **options)
-    end
-
     def requires_grad=(requires_grad)
       _requires_grad!(requires_grad)
     end
 
-    def requires_grad!(requires_grad = true)
-      _requires_grad!(requires_grad)
-    end
-
     def type(dtype)
       if dtype.is_a?(Class)
         raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
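A quick sketch of the behavior the remaining Ruby methods define (based only on the code shown above):

```ruby
require "torch"

x = Torch.tensor([42])
x.item # => 42 (numel == 1, so the conversion succeeds)

# Torch.tensor([1, 2]).item raises Torch::Error:
# "only one element tensors can be converted to Ruby scalars"

z = Torch.rand(3)
z.requires_grad = true # goes through the requires_grad= writer shown above
```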
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -29,7 +29,7 @@ module Torch
 
       # try to keep the random number generator in sync with Python
      # this makes it easy to compare results
-
+      _base_seed = Torch.empty([], dtype: :int64).random!.item
 
      indexes =
        if @shuffle
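Since the hunk above is about keeping shuffling reproducible relative to the seeded RNG, here is a short usage sketch (hedged; `TensorDataset` and these `DataLoader` options follow the gem's README-style API):

```ruby
require "torch"

Torch.manual_seed(1) # the base seed above is drawn from this RNG

x = Torch.rand(10, 3)
y = Torch.rand(10)
dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 4, shuffle: true)

loader.each do |xb, yb|
  # batches arrive in a shuffled order derived from the seeded RNG
end
```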
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -377,8 +377,6 @@ module Torch
       to_ruby(_load(File.binread(f)))
     end
 
-    # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
-
     def tensor(data, **options)
       if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
         numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
@@ -411,41 +409,6 @@ module Torch
       _tensor(data, size, tensor_options(**options))
     end
 
-    # --- begin like ---
-
-    def ones_like(input, **options)
-      ones(input.size, **like_options(input, options))
-    end
-
-    def empty_like(input, **options)
-      empty(input.size, **like_options(input, options))
-    end
-
-    def full_like(input, fill_value, **options)
-      full(input.size, fill_value, **like_options(input, options))
-    end
-
-    def rand_like(input, **options)
-      rand(input.size, **like_options(input, options))
-    end
-
-    def randint_like(input, low, high = nil, **options)
-      # ruby doesn't support input, low = 0, high, ...
-      if high.nil?
-        high = low
-        low = 0
-      end
-      randint(low, high, input.size, **like_options(input, options))
-    end
-
-    def randn_like(input, **options)
-      randn(input.size, **like_options(input, options))
-    end
-
-    def zeros_like(input, **options)
-      zeros(input.size, **like_options(input, options))
-    end
-
     # center option
     def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
@@ -572,13 +535,5 @@ module Torch
       end
       options
     end
-
-    def like_options(input, options)
-      options = options.dup
-      options[:dtype] ||= input.dtype
-      options[:layout] ||= input.layout
-      options[:device] ||= input.device
-      options
-    end
   end
 end
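The deleted `*_like` helpers and `like_options` show exactly how 0.9.0 built these tensors in Ruby. The snippet below restates that equivalence for reference (hedged: it assumes `ones_like` is still exposed in 0.9.1, presumably via the generated/native bindings rather than this file):

```ruby
require "torch"

x = Torch.rand(2, 3)

# What the removed Ruby helper did in 0.9.0, per the deleted code above:
a = Torch.ones(x.size, dtype: x.dtype, layout: x.layout, device: x.device)

# The equivalent call, assumed to still be available in 0.9.1:
b = Torch.ones_like(x)

a.dtype == b.dtype # => true
```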
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.9.0
+  version: 0.9.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date:
+date: 2022-02-03 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -227,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.
+rubygems_version: 3.3.3
 signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch