torch-rb 0.4.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -348,16 +348,6 @@ void Init_ext()
348
348
  *[](Tensor& self) {
349
349
  return self.is_contiguous();
350
350
  })
351
- .define_method(
352
- "addcmul!",
353
- *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
354
- return self.addcmul_(tensor1, tensor2, value);
355
- })
356
- .define_method(
357
- "addcdiv!",
358
- *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
359
- return self.addcdiv_(tensor1, tensor2, value);
360
- })
361
351
  .define_method(
362
352
  "_requires_grad!",
363
353
  *[](Tensor& self, bool requires_grad) {
@@ -372,7 +362,7 @@ void Init_ext()
372
362
  .define_method(
373
363
  "grad=",
374
364
  *[](Tensor& self, torch::Tensor& grad) {
375
- self.grad() = grad;
365
+ self.mutable_grad() = grad;
376
366
  })
377
367
  .define_method(
378
368
  "_dtype",
@@ -609,7 +599,7 @@ void Init_ext()
609
599
  .define_method(
610
600
  "grad=",
611
601
  *[](Parameter& self, torch::Tensor& grad) {
612
- self.grad() = grad;
602
+ self.mutable_grad() = grad;
613
603
  });
614
604
 
615
605
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
@@ -91,7 +91,7 @@ struct RubyArgs {
91
91
  inline c10::optional<int64_t> toInt64Optional(int i);
92
92
  inline c10::optional<bool> toBoolOptional(int i);
93
93
  inline c10::optional<double> toDoubleOptional(int i);
94
- // inline c10::OptionalArray<double> doublelistOptional(int i);
94
+ inline c10::OptionalArray<double> doublelistOptional(int i);
95
95
  // inline at::Layout layout(int i);
96
96
  // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
97
97
  inline c10::optional<at::Layout> layoutOptional(int i);
@@ -105,7 +105,7 @@ struct RubyArgs {
105
105
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
106
106
  // inline at::QScheme toQScheme(int i);
107
107
  inline std::string string(int i);
108
- // inline c10::optional<std::string> stringOptional(int i);
108
+ inline c10::optional<std::string> stringOptional(int i);
109
109
  // inline PyObject* pyobject(int i);
110
110
  inline int64_t toInt64(int i);
111
111
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -249,6 +249,25 @@ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
249
249
  return toDouble(i);
250
250
  }
251
251
 
252
+ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
253
+ if (NIL_P(args[i])) return {};
254
+
255
+ VALUE arg = args[i];
256
+ auto size = RARRAY_LEN(arg);
257
+ std::vector<double> res(size);
258
+ for (idx = 0; idx < size; idx++) {
259
+ VALUE obj = rb_ary_entry(arg, idx);
260
+ if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
261
+ res[idx] = from_ruby<double>(obj);
262
+ } else {
263
+ rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
264
+ signature.name.c_str(), signature.params[i].name.c_str(),
265
+ signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
266
+ }
267
+ }
268
+ return res;
269
+ }
270
+
252
271
  inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
253
272
  if (NIL_P(args[i])) return c10::nullopt;
254
273
 
@@ -285,6 +304,11 @@ inline std::string RubyArgs::string(int i) {
285
304
  return from_ruby<std::string>(args[i]);
286
305
  }
287
306
 
307
+ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
308
+ if (!args[i]) return c10::nullopt;
309
+ return from_ruby<std::string>(args[i]);
310
+ }
311
+
288
312
  inline int64_t RubyArgs::toInt64(int i) {
289
313
  if (NIL_P(args[i])) return signature.params[i].default_int;
290
314
  return from_ruby<int64_t>(args[i]);
@@ -19,6 +19,7 @@ using torch::TensorOptions;
19
19
  using torch::Layout;
20
20
  using torch::MemoryFormat;
21
21
  using torch::IntArrayRef;
22
+ using torch::ArrayRef;
22
23
  using torch::TensorList;
23
24
  using torch::Storage;
24
25
 
@@ -90,3 +90,10 @@ inline Object wrap(torch::TensorList x) {
90
90
  }
91
91
  return Object(a);
92
92
  }
93
+
94
+ inline Object wrap(std::tuple<double, double> x) {
95
+ Array a;
96
+ a.push(to_ruby<double>(std::get<0>(x)));
97
+ a.push(to_ruby<double>(std::get<1>(x)));
98
+ return Object(a);
99
+ }
@@ -261,6 +261,8 @@ module Torch
261
261
  Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
262
262
  elsif args.size == 1 && args.first.is_a?(Array)
263
263
  Torch.tensor(args.first, dtype: dtype, device: device)
264
+ elsif args.size == 0
265
+ Torch.empty(0, dtype: dtype, device: device)
264
266
  else
265
267
  Torch.empty(*args, dtype: dtype, device: device)
266
268
  end
@@ -434,7 +436,8 @@ module Torch
434
436
  zeros(input.size, **like_options(input, options))
435
437
  end
436
438
 
437
- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
439
+ # center option
440
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
438
441
  if center
439
442
  signal_dim = input.dim
440
443
  extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +445,7 @@ module Torch
442
445
  input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
443
446
  input = input.view(input.shape[-signal_dim..-1])
444
447
  end
445
- _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
446
- end
447
-
448
- def clamp(tensor, min, max)
449
- tensor = _clamp_min(tensor, min)
450
- _clamp_max(tensor, max)
448
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
451
449
  end
452
450
 
453
451
  private
@@ -39,14 +39,14 @@ module Torch
39
39
  state[:step] += 1
40
40
 
41
41
  if group[:weight_decay] != 0
42
- grad = grad.add(group[:weight_decay], p.data)
42
+ grad = grad.add(p.data, alpha: group[:weight_decay])
43
43
  end
44
44
 
45
- square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
45
+ square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
46
46
  std = square_avg.add(eps).sqrt!
47
47
  delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
48
48
  p.data.add!(delta, alpha: -group[:lr])
49
- acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
49
+ acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
50
50
  end
51
51
  end
52
52
 
@@ -49,7 +49,7 @@ module Torch
49
49
  if p.grad.data.sparse?
50
50
  raise Error, "weight_decay option is not compatible with sparse gradients"
51
51
  end
52
- grad = grad.add(group[:weight_decay], p.data)
52
+ grad = grad.add(p.data, alpha: group[:weight_decay])
53
53
  end
54
54
 
55
55
  clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
57
57
  if grad.sparse?
58
58
  raise NotImplementedYet
59
59
  else
60
- state[:sum].addcmul!(1, grad, grad)
60
+ state[:sum].addcmul!(grad, grad, value: 1)
61
61
  std = state[:sum].sqrt.add!(group[:eps])
62
- p.data.addcdiv!(-clr, grad, std)
62
+ p.data.addcdiv!(grad, std, value: -clr)
63
63
  end
64
64
  end
65
65
  end
@@ -58,7 +58,7 @@ module Torch
58
58
 
59
59
  # Decay the first and second moment running average coefficient
60
60
  exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
61
- exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
61
+ exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
62
62
  if amsgrad
63
63
  # Maintains the maximum of all 2nd moment running avg. till now
64
64
  Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
70
70
 
71
71
  step_size = group[:lr] / bias_correction1
72
72
 
73
- p.data.addcdiv!(-step_size, exp_avg, denom)
73
+ p.data.addcdiv!(exp_avg, denom, value: -step_size)
74
74
  end
75
75
  end
76
76
 
@@ -42,7 +42,7 @@ module Torch
42
42
  state[:step] += 1
43
43
 
44
44
  if group[:weight_decay] != 0
45
- grad = grad.add(group[:weight_decay], p.data)
45
+ grad = grad.add(p.data, alpha: group[:weight_decay])
46
46
  end
47
47
 
48
48
  # Update biased first moment estimate.
@@ -57,7 +57,7 @@ module Torch
57
57
  bias_correction = 1 - beta1 ** state[:step]
58
58
  clr = group[:lr] / bias_correction
59
59
 
60
- p.data.addcdiv!(-clr, exp_avg, exp_inf)
60
+ p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
61
61
  end
62
62
  end
63
63
 
@@ -59,7 +59,7 @@ module Torch
59
59
 
60
60
  # Decay the first and second moment running average coefficient
61
61
  exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
62
- exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
62
+ exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
63
63
  if amsgrad
64
64
  # Maintains the maximum of all 2nd moment running avg. till now
65
65
  Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
71
71
 
72
72
  step_size = group[:lr] / bias_correction1
73
73
 
74
- p.data.addcdiv!(-step_size, exp_avg, denom)
74
+ p.data.addcdiv!(exp_avg, denom, value: -step_size)
75
75
  end
76
76
  end
77
77
 
@@ -36,7 +36,7 @@ module Torch
36
36
  state[:step] += 1
37
37
 
38
38
  if group[:weight_decay] != 0
39
- grad = grad.add(group[:weight_decay], p.data)
39
+ grad = grad.add(p.data, alpha: group[:weight_decay])
40
40
  end
41
41
 
42
42
  # decay term
@@ -46,25 +46,25 @@ module Torch
46
46
  state[:step] += 1
47
47
 
48
48
  if group[:weight_decay] != 0
49
- grad = grad.add(group[:weight_decay], p.data)
49
+ grad = grad.add(p.data, alpha: group[:weight_decay])
50
50
  end
51
51
 
52
- square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
52
+ square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
53
53
 
54
54
  if group[:centered]
55
55
  grad_avg = state[:grad_avg]
56
- grad_avg.mul!(alpha).add!(1 - alpha, grad)
57
- avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
56
+ grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
57
+ avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
58
58
  else
59
59
  avg = square_avg.sqrt.add!(group[:eps])
60
60
  end
61
61
 
62
62
  if group[:momentum] > 0
63
63
  buf = state[:momentum_buffer]
64
- buf.mul!(group[:momentum]).addcdiv!(grad, avg)
65
- p.data.add!(-group[:lr], buf)
64
+ buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
65
+ p.data.add!(buf, alpha: -group[:lr])
66
66
  else
67
- p.data.addcdiv!(-group[:lr], grad, avg)
67
+ p.data.addcdiv!(grad, avg, value: -group[:lr])
68
68
  end
69
69
  end
70
70
  end
@@ -52,7 +52,7 @@ module Torch
52
52
  grad[sign.eq(etaminus)] = 0
53
53
 
54
54
  # update parameters
55
- p.data.addcmul!(-1, grad.sign, step_size)
55
+ p.data.addcmul!(grad.sign, step_size, value: -1)
56
56
 
57
57
  state[:prev].copy!(grad)
58
58
  end
@@ -36,14 +36,14 @@ module Torch
36
36
  end
37
37
  if momentum != 0
38
38
  param_state = @state[p]
39
- if !param_state.key(:momentum_buffer)
39
+ if !param_state.key?(:momentum_buffer)
40
40
  buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
41
41
  else
42
42
  buf = param_state[:momentum_buffer]
43
43
  buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
44
44
  end
45
45
  if nesterov
46
- d_p = d_p.add(momentum, buf)
46
+ d_p = d_p.add(buf, alpha: momentum)
47
47
  else
48
48
  d_p = buf
49
49
  end
@@ -174,5 +174,10 @@ module Torch
174
174
  return _random!(0, *args) if args.size == 1
175
175
  _random!(*args)
176
176
  end
177
+
178
+ # center option
179
# Tensor-method form of Torch.stft: uses this tensor as the input and forwards
# everything else, so the Ruby-side `center` padding in Torch.stft applies.
#
# @param args    positional arguments following the input (e.g. n_fft)
# @param options keyword options accepted by Torch.stft (hop_length:,
#   win_length:, window:, center:, pad_mode:, normalized:, onesided:,
#   return_complex:)
def stft(*args, **options)
  # Torch.stft takes the input tensor as its first positional parameter.
  # Keywords are forwarded explicitly so they survive Ruby 3 keyword separation.
  Torch.stft(self, *args, **options)
end
177
182
  end
178
183
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.4.0"
2
+ VERSION = "0.5.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.0
4
+ version: 0.5.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-09-27 00:00:00.000000000 Z
11
+ date: 2020-10-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -108,21 +108,7 @@ dependencies:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
110
  version: 0.1.1
111
- - !ruby/object:Gem::Dependency
112
- name: magro
113
- requirement: !ruby/object:Gem::Requirement
114
- requirements:
115
- - - ">="
116
- - !ruby/object:Gem::Version
117
- version: '0'
118
- type: :development
119
- prerelease: false
120
- version_requirements: !ruby/object:Gem::Requirement
121
- requirements:
122
- - - ">="
123
- - !ruby/object:Gem::Version
124
- version: '0'
125
- description:
111
+ description:
126
112
  email: andrew@chartkick.com
127
113
  executables: []
128
114
  extensions:
@@ -288,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
288
274
  licenses:
289
275
  - BSD-3-Clause
290
276
  metadata: {}
291
- post_install_message:
277
+ post_install_message:
292
278
  rdoc_options: []
293
279
  require_paths:
294
280
  - lib
@@ -303,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
303
289
  - !ruby/object:Gem::Version
304
290
  version: '0'
305
291
  requirements: []
306
- rubygems_version: 3.1.2
307
- signing_key:
292
+ rubygems_version: 3.1.4
293
+ signing_key:
308
294
  specification_version: 4
309
295
  summary: Deep learning for Ruby, powered by LibTorch
310
296
  test_files: []