torch-rb 0.9.0 → 0.9.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '09221364dad232f1b76129fe9dc9407675cc2afbd03bd1339c736d4eec752df7'
4
- data.tar.gz: a37a0584aed809009ebd74e7c0da9430481ccabf1ac915d2468ee7511c249588
3
+ metadata.gz: f5224c74f6e74ed04396dfa0414400af5cb20bc5e654320421116723ffcb8e83
4
+ data.tar.gz: 6a2881ddacb7610a231ebd5a1c24d0f71a2662f16f360102141dfd0893c13346
5
5
  SHA512:
6
- metadata.gz: c220d35971b9ce3e5a7a80f6a5d1ae4324f3524d0e0171c680deced66fe2f29342a46eecb1a4447d84a401a677c7bb1ef910a0c7ee6c925ea4b578b7e5712772
7
- data.tar.gz: 807fe2907de1caac92da6dddb0154b7971dda3aa0ee2c53f6b3046732f4bf3c02310e59a6441efa0fdf1ad3d0ddb5dcd7a3c1a946adbfbea73b6b30f10a71487
6
+ metadata.gz: 045432235e1c691ce85fb937a0562d93b1d9bc312fc648d40dafcc24857eeec4f84e6ceab397793171b0046ccfd785a6caeba37902925dcff1f73c760dd57cec
7
+ data.tar.gz: 2520fa17dcd13be52aaf1256f431d2951f8113f862189b8691f877dcb48ac9f71ac70719600e27d8c8972494c0b7f11c0983c3330fed7ca6aeed552e8500ec22
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.9.1 (2022-02-02)
2
+
3
+ - Moved `like` methods to C++
4
+ - Fixed memory issue
5
+
1
6
  ## 0.9.0 (2021-10-23)
2
7
 
3
8
  - Updated LibTorch to 1.10.0
data/README.md CHANGED
@@ -21,10 +21,10 @@ brew install libtorch
21
21
  Add this line to your application’s Gemfile:
22
22
 
23
23
  ```ruby
24
- gem 'torch-rb'
24
+ gem "torch-rb"
25
25
  ```
26
26
 
27
- It can take a few minutes to compile the extension.
27
+ It can take 5-10 minutes to compile the extension.
28
28
 
29
29
  ## Getting Started
30
30
 
@@ -79,7 +79,7 @@ b = Torch.zeros(2, 3)
79
79
 
80
80
  Each tensor has four properties
81
81
 
82
- - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `float64`, or `:bool`
82
+ - `dtype` - the data type - `:uint8`, `:int8`, `:int16`, `:int32`, `:int64`, `:float32`, `:float64`, or `:bool`
83
83
  - `layout` - `:strided` (dense) or `:sparse`
84
84
  - `device` - the compute device, like CPU or GPU
85
85
  - `requires_grad` - whether or not to record gradients
@@ -7,11 +7,11 @@
7
7
  void init_backends(Rice::Module& m) {
8
8
  auto rb_mBackends = Rice::define_module_under(m, "Backends");
9
9
 
10
- Rice::define_module_under(rb_mBackends, "OpenMP")
10
+ Rice::define_module_under(rb_mBackends, "OpenMP")
11
11
  .add_handler<torch::Error>(handle_error)
12
12
  .define_singleton_function("available?", &torch::hasOpenMP);
13
13
 
14
- Rice::define_module_under(rb_mBackends, "MKL")
14
+ Rice::define_module_under(rb_mBackends, "MKL")
15
15
  .add_handler<torch::Error>(handle_error)
16
16
  .define_singleton_function("available?", &torch::hasMKL);
17
17
  }
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
472
472
  auto param_idx = find_param(signature, key);
473
473
  if (param_idx < 0) {
474
474
  rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
475
- signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
475
+ signature.name.c_str(), rb_id2name(rb_to_id(key)));
476
476
  }
477
477
 
478
478
  if (param_idx < num_pos_args) {
479
479
  rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
480
- signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
480
+ signature.name.c_str(), rb_id2name(rb_to_id(key)));
481
481
  }
482
482
  }
483
483
 
@@ -235,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
235
235
 
236
236
  auto it = dtype_map.find(args[i]);
237
237
  if (it == dtype_map.end()) {
238
- rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(args[i]).c_str());
238
+ rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
239
239
  }
240
240
  return it->second;
241
241
  }
@@ -293,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
293
293
 
294
294
  auto it = layout_map.find(args[i]);
295
295
  if (it == layout_map.end()) {
296
- rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(args[i]).c_str());
296
+ rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
297
297
  }
298
298
  return it->second;
299
299
  }
@@ -325,15 +325,15 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
325
325
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
326
326
  }
327
327
 
328
+ // string_view does not own data
328
329
  inline c10::string_view RubyArgs::stringView(int i) {
329
- auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
330
- return c10::string_view(str.data(), str.size());
330
+ return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
331
331
  }
332
332
 
333
+ // string_view does not own data
333
334
  inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
334
335
  if (NIL_P(args[i])) return c10::nullopt;
335
- auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
336
- return c10::string_view(str.data(), str.size());
336
+ return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
337
337
  }
338
338
 
339
339
  inline int64_t RubyArgs::toInt64(int i) {
data/ext/torch/utils.h CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
16
16
  return Rice::Symbol(str);
17
17
  }
18
18
 
19
- inline std::string THPUtils_unpackSymbol(VALUE obj) {
20
- Check_Type(obj, T_SYMBOL);
21
- obj = rb_funcall(obj, rb_intern("to_s"), 0);
22
- return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
23
- }
24
-
25
19
  inline std::string THPUtils_unpackString(VALUE obj) {
26
20
  Check_Type(obj, T_STRING);
27
21
  return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
@@ -247,7 +247,7 @@ module Torch
247
247
  # length includes spaces and comma between elements
248
248
  element_length = formatter.width + 2
249
249
  elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
250
- char_per_line = element_length * elements_per_line
250
+ _char_per_line = element_length * elements_per_line
251
251
 
252
252
  if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
253
253
  data = (
@@ -571,7 +571,7 @@ module Torch
571
571
  end
572
572
 
573
573
  def _interp_output_size(closed_over_args)
574
- input, size, scale_factor, recompute_scale_factor = closed_over_args
574
+ input, size, scale_factor, _recompute_scale_factor = closed_over_args
575
575
  dim = input.dim - 2
576
576
  if size.nil? && scale_factor.nil?
577
577
  raise ArgumentError, "either size or scale_factor should be defined"
@@ -56,7 +56,7 @@ module Torch
56
56
  attn_mask: nil, dropout_p: 0.0
57
57
  )
58
58
 
59
- b, nt, e = q.shape
59
+ _b, _nt, e = q.shape
60
60
 
61
61
  q = q / Math.sqrt(e)
62
62
 
data/lib/torch/tensor.rb CHANGED
@@ -106,6 +106,7 @@ module Torch
106
106
  size(0)
107
107
  end
108
108
 
109
+ remove_method :item
109
110
  def item
110
111
  if numel != 1
111
112
  raise Error, "only one element tensors can be converted to Ruby scalars"
@@ -133,18 +134,10 @@ module Torch
133
134
  cls.from_string(_data_str).reshape(*shape)
134
135
  end
135
136
 
136
- def new_ones(*size, **options)
137
- Torch.ones_like(Torch.empty(*size), **options)
138
- end
139
-
140
137
  def requires_grad=(requires_grad)
141
138
  _requires_grad!(requires_grad)
142
139
  end
143
140
 
144
- def requires_grad!(requires_grad = true)
145
- _requires_grad!(requires_grad)
146
- end
147
-
148
141
  def type(dtype)
149
142
  if dtype.is_a?(Class)
150
143
  raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
@@ -29,7 +29,7 @@ module Torch
29
29
 
30
30
  # try to keep the random number generator in sync with Python
31
31
  # this makes it easy to compare results
32
- base_seed = Torch.empty([], dtype: :int64).random!.item
32
+ _base_seed = Torch.empty([], dtype: :int64).random!.item
33
33
 
34
34
  indexes =
35
35
  if @shuffle
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.9.0"
2
+ VERSION = "0.9.1"
3
3
  end
data/lib/torch.rb CHANGED
@@ -377,8 +377,6 @@ module Torch
377
377
  to_ruby(_load(File.binread(f)))
378
378
  end
379
379
 
380
- # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
381
-
382
380
  def tensor(data, **options)
383
381
  if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
384
382
  numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
@@ -411,41 +409,6 @@ module Torch
411
409
  _tensor(data, size, tensor_options(**options))
412
410
  end
413
411
 
414
- # --- begin like ---
415
-
416
- def ones_like(input, **options)
417
- ones(input.size, **like_options(input, options))
418
- end
419
-
420
- def empty_like(input, **options)
421
- empty(input.size, **like_options(input, options))
422
- end
423
-
424
- def full_like(input, fill_value, **options)
425
- full(input.size, fill_value, **like_options(input, options))
426
- end
427
-
428
- def rand_like(input, **options)
429
- rand(input.size, **like_options(input, options))
430
- end
431
-
432
- def randint_like(input, low, high = nil, **options)
433
- # ruby doesn't support input, low = 0, high, ...
434
- if high.nil?
435
- high = low
436
- low = 0
437
- end
438
- randint(low, high, input.size, **like_options(input, options))
439
- end
440
-
441
- def randn_like(input, **options)
442
- randn(input.size, **like_options(input, options))
443
- end
444
-
445
- def zeros_like(input, **options)
446
- zeros(input.size, **like_options(input, options))
447
- end
448
-
449
412
  # center option
450
413
  def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
451
414
  if center
@@ -572,13 +535,5 @@ module Torch
572
535
  end
573
536
  options
574
537
  end
575
-
576
- def like_options(input, options)
577
- options = options.dup
578
- options[:dtype] ||= input.dtype
579
- options[:layout] ||= input.layout
580
- options[:device] ||= input.device
581
- options
582
- end
583
538
  end
584
539
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.9.0
4
+ version: 0.9.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-10-23 00:00:00.000000000 Z
11
+ date: 2022-02-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -227,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
227
227
  - !ruby/object:Gem::Version
228
228
  version: '0'
229
229
  requirements: []
230
- rubygems_version: 3.2.22
230
+ rubygems_version: 3.3.3
231
231
  signing_key:
232
232
  specification_version: 4
233
233
  summary: Deep learning for Ruby, powered by LibTorch