torch-rb 0.4.0 → 0.5.2

@@ -348,16 +348,6 @@ void Init_ext()
       *[](Tensor& self) {
        return self.is_contiguous();
      })
-    .define_method(
-      "addcmul!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcmul_(tensor1, tensor2, value);
-      })
-    .define_method(
-      "addcdiv!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcdiv_(tensor1, tensor2, value);
-      })
     .define_method(
       "_requires_grad!",
       *[](Tensor& self, bool requires_grad) {
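
The hand-written `addcmul!` and `addcdiv!` bindings above took the scalar first; 0.5 removes them in favor of the generated native bindings, which use PyTorch's keyword form with the scalar last. A minimal sketch of the new call style (values assumed):

    require "torch"

    a  = Torch.ones([2])
    t1 = Torch.full([2], 2.0)
    t2 = Torch.full([2], 4.0)

    a.addcmul!(t1, t2, value: 0.5)   # a += 0.5 * t1 * t2     => [5.0, 5.0]
    a.addcdiv!(t1, t2, value: -2.0)  # a += -2.0 * (t1 / t2)  => [4.0, 4.0]
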
@@ -372,7 +362,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Tensor& self, torch::Tensor& grad) {
-        self.grad() = grad;
+        self.mutable_grad() = grad;
       })
     .define_method(
       "_dtype",
@@ -609,7 +599,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Parameter& self, torch::Tensor& grad) {
-        self.grad() = grad;
+        self.mutable_grad() = grad;
       });

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
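
Both `Tensor#grad=` and `Parameter#grad=` now assign through `mutable_grad()`, since `grad()` returns a const reference in newer LibTorch. Behavior on the Ruby side is unchanged; a quick sketch (shapes assumed):

    require "torch"

    w = Torch.randn([2, 2], requires_grad: true)
    w.grad = Torch.zeros([2, 2])  # written via mutable_grad() in the extension
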
@@ -91,7 +91,7 @@ struct RubyArgs {
   inline c10::optional<int64_t> toInt64Optional(int i);
   inline c10::optional<bool> toBoolOptional(int i);
   inline c10::optional<double> toDoubleOptional(int i);
-  // inline c10::OptionalArray<double> doublelistOptional(int i);
+  inline c10::OptionalArray<double> doublelistOptional(int i);
   // inline at::Layout layout(int i);
   // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
   inline c10::optional<at::Layout> layoutOptional(int i);
@@ -105,7 +105,7 @@ struct RubyArgs {
   inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
-  // inline c10::optional<std::string> stringOptional(int i);
+  inline c10::optional<std::string> stringOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -249,6 +249,25 @@ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
   return toDouble(i);
 }

+inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
+  if (NIL_P(args[i])) return {};
+
+  VALUE arg = args[i];
+  auto size = RARRAY_LEN(arg);
+  std::vector<double> res(size);
+  for (int idx = 0; idx < size; idx++) {
+    VALUE obj = rb_ary_entry(arg, idx);
+    if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
+      res[idx] = from_ruby<double>(obj);
+    } else {
+      rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
+        signature.name.c_str(), signature.params[i].name.c_str(),
+        signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
+    }
+  }
+  return res;
+}
+
 inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;

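
`doublelistOptional` lets the generated bindings accept an optional Ruby Array of floats, which LibTorch 1.7 needs for `double[]`-typed arguments such as upsampling scale factors. A hedged sketch, assuming `Torch::NN::F.interpolate` routes `scale_factor` through such an argument:

    require "torch"

    x = Torch.randn([1, 1, 4, 4])
    y = Torch::NN::F.interpolate(x, scale_factor: [2.0, 2.0])  # Array parsed by doublelistOptional
    y.shape  # => [1, 1, 8, 8]
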
@@ -285,6 +304,11 @@ inline std::string RubyArgs::string(int i) {
   return from_ruby<std::string>(args[i]);
 }

+inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
+  if (!args[i]) return c10::nullopt;
+  return from_ruby<std::string>(args[i]);
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return from_ruby<int64_t>(args[i]);
@@ -19,6 +19,7 @@ using torch::TensorOptions;
 using torch::Layout;
 using torch::MemoryFormat;
 using torch::IntArrayRef;
+using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;

@@ -90,3 +90,10 @@ inline Object wrap(torch::TensorList x) {
   }
   return Object(a);
 }
+
+inline Object wrap(std::tuple<double, double> x) {
+  Array a;
+  a.push(to_ruby<double>(std::get<0>(x)));
+  a.push(to_ruby<double>(std::get<1>(x)));
+  return Object(a);
+}
@@ -261,6 +261,8 @@ module Torch
         Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
       elsif args.size == 1 && args.first.is_a?(Array)
         Torch.tensor(args.first, dtype: dtype, device: device)
+      elsif args.size == 0
+        Torch.empty(0, dtype: dtype, device: device)
       else
         Torch.empty(*args, dtype: dtype, device: device)
       end
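
Constructing a tensor class with no arguments (for example `Torch::FloatTensor.new`) now returns an empty tensor instead of falling through to `Torch.empty` with no sizes. A quick check (output assumed):

    require "torch"

    t = Torch::FloatTensor.new
    t.shape  # => [0]
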
@@ -434,7 +436,8 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end

-    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
         signal_dim = input.dim
         extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +445,7 @@ module Torch
         input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
         input = input.view(input.shape[-signal_dim..-1])
       end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
     end

     private
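
`stft` now forwards the `return_complex` flag that LibTorch 1.7 expects, and the Ruby-side `clamp` shim built from `_clamp_min`/`_clamp_max` is dropped in favor of the native binding. A hedged sketch of the updated call (sizes assumed):

    require "torch"

    signal = Torch.randn([1024])
    spec = Torch.stft(signal, 256, hop_length: 64, return_complex: false)
    spec.shape  # => [129, 17, 2] (onesided bins x frames x re/im)
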
@@ -39,14 +39,14 @@ module Torch
         state[:step] += 1

         if group[:weight_decay] != 0
-          grad = grad.add(group[:weight_decay], p.data)
+          grad = grad.add(p.data, alpha: group[:weight_decay])
         end

-        square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+        square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
         std = square_avg.add(eps).sqrt!
         delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
         p.data.add!(delta, alpha: -group[:lr])
-        acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+        acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
       end
     end

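
The same deprecation cleanup runs through every optimizer below: `grad.add(p.data, alpha: wd)` computes `grad + wd * p.data`, and `addcmul!`/`addcdiv!` move the scalar into a trailing `value:` keyword. For instance (values assumed):

    require "torch"

    grad = Torch.ones([2])
    p = Torch.full([2], 3.0)
    grad.add(p, alpha: 0.1)  # => [1.3, 1.3], i.e. grad + 0.1 * p
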
@@ -49,7 +49,7 @@ module Torch
           if p.grad.data.sparse?
             raise Error, "weight_decay option is not compatible with sparse gradients"
           end
-          grad = grad.add(group[:weight_decay], p.data)
+          grad = grad.add(p.data, alpha: group[:weight_decay])
         end

         clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
         if grad.sparse?
           raise NotImplementedYet
         else
-          state[:sum].addcmul!(1, grad, grad)
+          state[:sum].addcmul!(grad, grad, value: 1)
           std = state[:sum].sqrt.add!(group[:eps])
-          p.data.addcdiv!(-clr, grad, std)
+          p.data.addcdiv!(grad, std, value: -clr)
         end
       end
     end
@@ -58,7 +58,7 @@ module Torch

         # Decay the first and second moment running average coefficient
         exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-        exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+        exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
         if amsgrad
           # Maintains the maximum of all 2nd moment running avg. till now
           Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch

         step_size = group[:lr] / bias_correction1

-        p.data.addcdiv!(-step_size, exp_avg, denom)
+        p.data.addcdiv!(exp_avg, denom, value: -step_size)
       end
     end

@@ -42,7 +42,7 @@ module Torch
         state[:step] += 1

         if group[:weight_decay] != 0
-          grad = grad.add(group[:weight_decay], p.data)
+          grad = grad.add(p.data, alpha: group[:weight_decay])
         end

         # Update biased first moment estimate.
@@ -57,7 +57,7 @@ module Torch
         bias_correction = 1 - beta1 ** state[:step]
         clr = group[:lr] / bias_correction

-        p.data.addcdiv!(-clr, exp_avg, exp_inf)
+        p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
       end
     end

@@ -59,7 +59,7 @@ module Torch

         # Decay the first and second moment running average coefficient
         exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-        exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+        exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
         if amsgrad
           # Maintains the maximum of all 2nd moment running avg. till now
           Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch

         step_size = group[:lr] / bias_correction1

-        p.data.addcdiv!(-step_size, exp_avg, denom)
+        p.data.addcdiv!(exp_avg, denom, value: -step_size)
       end
     end

@@ -36,7 +36,7 @@ module Torch
         state[:step] += 1

         if group[:weight_decay] != 0
-          grad = grad.add(group[:weight_decay], p.data)
+          grad = grad.add(p.data, alpha: group[:weight_decay])
         end

         # decay term
@@ -46,25 +46,25 @@ module Torch
         state[:step] += 1

         if group[:weight_decay] != 0
-          grad = grad.add(group[:weight_decay], p.data)
+          grad = grad.add(p.data, alpha: group[:weight_decay])
         end

-        square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+        square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)

         if group[:centered]
           grad_avg = state[:grad_avg]
-          grad_avg.mul!(alpha).add!(1 - alpha, grad)
-          avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+          grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+          avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
         else
           avg = square_avg.sqrt.add!(group[:eps])
         end

         if group[:momentum] > 0
           buf = state[:momentum_buffer]
-          buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-          p.data.add!(-group[:lr], buf)
+          buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+          p.data.add!(buf, alpha: -group[:lr])
         else
-          p.data.addcdiv!(-group[:lr], grad, avg)
+          p.data.addcdiv!(grad, avg, value: -group[:lr])
         end
       end
     end
@@ -52,7 +52,7 @@ module Torch
         grad[sign.eq(etaminus)] = 0

         # update parameters
-        p.data.addcmul!(-1, grad.sign, step_size)
+        p.data.addcmul!(grad.sign, step_size, value: -1)

         state[:prev].copy!(grad)
       end
@@ -36,14 +36,14 @@ module Torch
           end
           if momentum != 0
             param_state = @state[p]
-            if !param_state.key(:momentum_buffer)
+            if !param_state.key?(:momentum_buffer)
               buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
             else
               buf = param_state[:momentum_buffer]
               buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
             end
             if nesterov
-              d_p = d_p.add(momentum, buf)
+              d_p = d_p.add(buf, alpha: momentum)
             else
               d_p = buf
             end
@@ -174,5 +174,10 @@ module Torch
       return _random!(0, *args) if args.size == 1
       _random!(*args)
     end
+
+    # center option
+    def stft(*args)
+      Torch.stft(self, *args)
+    end
   end
 end
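
With the delegating method in place, `stft` can be called directly on a tensor; a usage sketch:

    require "torch"

    Torch.randn([1024]).stft(256, return_complex: false)
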
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.4.0"
+  VERSION = "0.5.2"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.4.0
+  version: 0.5.2
 platform: ruby
 authors:
 - Andrew Kane
-autorequire:
+autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-09-27 00:00:00.000000000 Z
+date: 2020-10-29 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -108,21 +108,7 @@ dependencies:
   - - ">="
     - !ruby/object:Gem::Version
       version: 0.1.1
-- !ruby/object:Gem::Dependency
-  name: magro
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-description:
+description:
 email: andrew@chartkick.com
 executables: []
 extensions:
@@ -288,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
 licenses:
 - BSD-3-Clause
 metadata: {}
-post_install_message:
+post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -303,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.1.2
-signing_key:
+rubygems_version: 3.1.4
+signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch
 test_files: []