torch-rb 0.4.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -44,8 +44,13 @@ std::vector<TensorIndex> index_vector(Array a) {
     if (obj.is_instance_of(rb_cInteger)) {
       indices.push_back(from_ruby<int64_t>(obj));
     } else if (obj.is_instance_of(rb_cRange)) {
-      torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
-      torch::optional<int64_t> stop_index = -1;
+      torch::optional<int64_t> start_index = torch::nullopt;
+      torch::optional<int64_t> stop_index = torch::nullopt;
+
+      Object begin = obj.call("begin");
+      if (!begin.is_nil()) {
+        start_index = from_ruby<int64_t>(begin);
+      }

       Object end = obj.call("end");
       if (!end.is_nil()) {
@@ -53,12 +58,14 @@ std::vector<TensorIndex> index_vector(Array a) {
       }

       Object exclude_end = obj.call("exclude_end?");
-      if (!exclude_end) {
+      if (stop_index.has_value() && !exclude_end) {
         if (stop_index.value() == -1) {
           stop_index = torch::nullopt;
         } else {
          stop_index = stop_index.value() + 1;
         }
+      } else if (!stop_index.has_value() && exclude_end) {
+        stop_index = -1;
       }

       indices.push_back(torch::indexing::Slice(start_index, stop_index));
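Note: with the range handling above, beginless and endless Ruby ranges now map to open-ended slices, and exclude_end? only adjusts a stop index that is actually present. A rough sketch of the resulting indexing behavior (illustrative, not part of the diff):

  x = Torch.arange(10)
  x[2..]    # endless range: slice from index 2 to the end
  x[..4]    # beginless range (Ruby 2.7+): slice from the start through index 4
  x[1...4]  # exclusive end: elements 1, 2, 3
  x[0..-2]  # inclusive negative end: everything except the last element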
@@ -348,16 +355,6 @@ void Init_ext()
       *[](Tensor& self) {
        return self.is_contiguous();
       })
-    .define_method(
-      "addcmul!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcmul_(tensor1, tensor2, value);
-      })
-    .define_method(
-      "addcdiv!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcdiv_(tensor1, tensor2, value);
-      })
     .define_method(
       "_requires_grad!",
       *[](Tensor& self, bool requires_grad) {
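Note: the hand-written addcmul! and addcdiv! bindings removed above used the old (scalar, tensor, tensor) argument order; in 0.5.x these methods instead take the tensors first and the scalar as a value: keyword. The optimizer changes later in this diff show the new call shape; a sketch using variable names borrowed from those hunks:

  # 0.4.x: square_avg.addcmul!(1 - rho, grad, grad)
  square_avg.addcmul!(grad, grad, value: 1 - rho)

  # 0.4.x: p.data.addcdiv!(-step_size, exp_avg, denom)
  p.data.addcdiv!(exp_avg, denom, value: -step_size)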
@@ -372,7 +369,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Tensor& self, torch::Tensor& grad) {
-        self.grad() = grad;
+        self.mutable_grad() = grad;
       })
     .define_method(
       "_dtype",
@@ -609,7 +606,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Parameter& self, torch::Tensor& grad) {
-        self.grad() = grad;
+        self.mutable_grad() = grad;
       });

   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
@@ -628,5 +625,7 @@ void Init_ext()
   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
     .add_handler<torch::Error>(handle_error)
     .define_singleton_method("available?", &torch::cuda::is_available)
-    .define_singleton_method("device_count", &torch::cuda::device_count);
+    .define_singleton_method("device_count", &torch::cuda::device_count)
+    .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+    .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
 }
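Note: the two new singleton methods expose CUDA random seeding directly; assuming a CUDA-enabled build, usage from Ruby looks like:

  Torch::CUDA.available?          # => true on a CUDA build with a visible GPU
  Torch::CUDA.manual_seed(42)     # seed the current device
  Torch::CUDA.manual_seed_all(42) # seed every visible device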
@@ -91,7 +91,7 @@ struct RubyArgs {
   inline c10::optional<int64_t> toInt64Optional(int i);
   inline c10::optional<bool> toBoolOptional(int i);
   inline c10::optional<double> toDoubleOptional(int i);
-  // inline c10::OptionalArray<double> doublelistOptional(int i);
+  inline c10::OptionalArray<double> doublelistOptional(int i);
   // inline at::Layout layout(int i);
   // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
   inline c10::optional<at::Layout> layoutOptional(int i);
@@ -105,7 +105,7 @@ struct RubyArgs {
   inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
-  // inline c10::optional<std::string> stringOptional(int i);
+  inline c10::optional<std::string> stringOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -249,6 +249,25 @@ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
   return toDouble(i);
 }

+inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
+  if (NIL_P(args[i])) return {};
+
+  VALUE arg = args[i];
+  auto size = RARRAY_LEN(arg);
+  std::vector<double> res(size);
+  for (idx = 0; idx < size; idx++) {
+    VALUE obj = rb_ary_entry(arg, idx);
+    if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
+      res[idx] = from_ruby<double>(obj);
+    } else {
+      rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
+        signature.name.c_str(), signature.params[i].name.c_str(),
+        signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
+    }
+  }
+  return res;
+}
+
 inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;

@@ -285,6 +304,11 @@ inline std::string RubyArgs::string(int i) {
   return from_ruby<std::string>(args[i]);
 }

+inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
+  if (!args[i]) return c10::nullopt;
+  return from_ruby<std::string>(args[i]);
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return from_ruby<int64_t>(args[i]);
@@ -19,6 +19,7 @@ using torch::TensorOptions;
 using torch::Layout;
 using torch::MemoryFormat;
 using torch::IntArrayRef;
+using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;

@@ -90,3 +90,10 @@ inline Object wrap(torch::TensorList x) {
   }
   return Object(a);
 }
+
+inline Object wrap(std::tuple<double, double> x) {
+  Array a;
+  a.push(to_ruby<double>(std::get<0>(x)));
+  a.push(to_ruby<double>(std::get<1>(x)));
+  return Object(a);
+}
@@ -261,6 +261,8 @@ module Torch
          Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
        elsif args.size == 1 && args.first.is_a?(Array)
          Torch.tensor(args.first, dtype: dtype, device: device)
+        elsif args.size == 0
+          Torch.empty(0, dtype: dtype, device: device)
        else
          Torch.empty(*args, dtype: dtype, device: device)
        end
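Note: constructing a dtype-specific tensor class with no arguments now yields an empty, 0-element tensor instead of falling through to Torch.empty with no sizes. An illustrative sketch, assuming the float tensor class:

  t = Torch::FloatTensor.new
  t.shape  # => [0]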
@@ -434,7 +436,8 @@ module Torch
      zeros(input.size, **like_options(input, options))
    end

-    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
      if center
        signal_dim = input.dim
        extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +445,7 @@ module Torch
        input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
        input = input.view(input.shape[-signal_dim..-1])
      end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
    end

    private
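Note: stft now forwards a return_complex argument to the native _stft call, and the hand-written clamp wrapper is dropped (clamp is presumably served by the generated native bindings in 0.5.x). A hedged usage sketch; the hann_window helper is assumed to be available in the installed build:

  signal = Torch.randn(1000)
  spec = Torch.stft(signal, 256, hop_length: 128, window: Torch.hann_window(256), return_complex: false)
  spec.shape  # real and imaginary parts stacked in the last dimension when return_complex is false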
@@ -113,35 +113,53 @@ module Torch
        forward(*input, **kwargs)
      end

-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
        destination ||= {}
-        named_parameters.each do |k, v|
-          destination[k] = v
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
        end
        destination
      end

-      # TODO add strict option
-      # TODO match PyTorch behavior
-      def load_state_dict(state_dict)
-        state_dict.each do |k, input_param|
-          k1, k2 = k.split(".", 2)
-          mod = named_modules[k1]
-          if mod.is_a?(Module)
-            param = mod.named_parameters[k2]
-            if param.is_a?(Parameter)
-              Torch.no_grad do
-                param.copy!(input_param)
-              end
-            else
-              raise Error, "Unknown parameter: #{k1}"
-            end
-          else
-            raise Error, "Unknown module: #{k1}"
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
          end
        end

-        # TODO return missing keys and unexpected keys
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
+          end
+        end
+
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
        nil
      end

@@ -300,6 +318,68 @@ module Torch
      def dict
        instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
      end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
    end
  end
end
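Note: state_dict now recurses into child modules with dot-separated key prefixes (via save_to_state_dict), and load_state_dict mirrors PyTorch by collecting missing and unexpected keys and raising on the first error when strict. A small illustrative sketch with a hypothetical module:

  class TinyNet < Torch::NN::Module
    def initialize
      super()
      @fc = Torch::NN::Linear.new(4, 2)
    end

    def forward(x)
      @fc.call(x)
    end
  end

  net = TinyNet.new
  net.state_dict.keys                 # => ["fc.weight", "fc.bias"]
  net.load_state_dict(net.state_dict) # => nil; raises if keys are missing or unexpected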
@@ -39,14 +39,14 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

-          square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+          square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
          std = square_avg.add(eps).sqrt!
          delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
          p.data.add!(delta, alpha: -group[:lr])
-          acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+          acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
        end
      end

@@ -49,7 +49,7 @@ module Torch
            if p.grad.data.sparse?
              raise Error, "weight_decay option is not compatible with sparse gradients"
            end
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

          clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
          if grad.sparse?
            raise NotImplementedYet
          else
-            state[:sum].addcmul!(1, grad, grad)
+            state[:sum].addcmul!(grad, grad, value: 1)
            std = state[:sum].sqrt.add!(group[:eps])
-            p.data.addcdiv!(-clr, grad, std)
+            p.data.addcdiv!(grad, std, value: -clr)
          end
        end
      end
@@ -58,7 +58,7 @@ module Torch

          # Decay the first and second moment running average coefficient
          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
          if amsgrad
            # Maintains the maximum of all 2nd moment running avg. till now
            Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch

          step_size = group[:lr] / bias_correction1

-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
        end
      end

@@ -42,7 +42,7 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

          # Update biased first moment estimate.
@@ -57,7 +57,7 @@ module Torch
          bias_correction = 1 - beta1 ** state[:step]
          clr = group[:lr] / bias_correction

-          p.data.addcdiv!(-clr, exp_avg, exp_inf)
+          p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
        end
      end

@@ -59,7 +59,7 @@ module Torch

          # Decay the first and second moment running average coefficient
          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
          if amsgrad
            # Maintains the maximum of all 2nd moment running avg. till now
            Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch

          step_size = group[:lr] / bias_correction1

-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
        end
      end

@@ -36,7 +36,7 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

          # decay term
@@ -46,25 +46,25 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

-          square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+          square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)

          if group[:centered]
            grad_avg = state[:grad_avg]
-            grad_avg.mul!(alpha).add!(1 - alpha, grad)
-            avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+            grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+            avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
          else
            avg = square_avg.sqrt.add!(group[:eps])
          end

          if group[:momentum] > 0
            buf = state[:momentum_buffer]
-            buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-            p.data.add!(-group[:lr], buf)
+            buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+            p.data.add!(buf, alpha: -group[:lr])
          else
-            p.data.addcdiv!(-group[:lr], grad, avg)
+            p.data.addcdiv!(grad, avg, value: -group[:lr])
          end
        end
      end